mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
from typing import Literal
|
|
|
|
import pydantic
|
|
from oasst_shared.schemas import inference
|
|
|
|
|
|
class GenerateStreamParameters(pydantic.BaseModel):
    """Sampling parameters for a streaming generation request.

    The sampling fields are all optional. ``from_work_request`` copies them
    one-to-one from an ``inference.WorkRequest``. ``stop`` and ``details``
    always keep their class defaults there.
    """

    # Explicit ``= None`` defaults: pydantic v1 already treats ``X | None``
    # fields as defaulting to None. Pydantic v2 would make them required
    # without the explicit default.
    max_new_tokens: int | None = None
    do_sample: bool | None = None
    top_k: int | None = None
    top_p: float | None = None
    temperature: float | None = None
    repetition_penalty: float | None = None
    seed: int | None = None
    # Presumably stop sequences forwarded to the generation backend (confirm).
    # Pydantic copies mutable defaults per instance, so this list is not shared.
    # TODO: make this a bit more workable because it's multiple tokens
    stop: list[str] = ["\nUser:", "\nAssistant:"]
    details: bool = True

    @classmethod
    def from_work_request(cls, work_request: inference.WorkRequest) -> "GenerateStreamParameters":
        """Build parameters from the sampling settings of ``work_request``.

        Args:
            work_request: source of the sampling fields, which are copied as-is.

        Returns:
            An instance of ``cls`` (``GenerateStreamParameters`` unless called
            on a subclass). ``stop`` and ``details`` keep their defaults.
        """
        return cls(
            max_new_tokens=work_request.max_new_tokens,
            do_sample=work_request.do_sample,
            top_k=work_request.top_k,
            top_p=work_request.top_p,
            temperature=work_request.temperature,
            repetition_penalty=work_request.repetition_penalty,
            seed=work_request.seed,
        )
|
|
|
|
|
|
class Token(pydantic.BaseModel):
    """A single generated token as carried by a stream response."""

    text: str
    logprob: float
    id: int

    def __len__(self) -> int:
        """Return the number of characters in this token's text."""
        token_text = self.text
        return len(token_text)

    def to_token_response(self) -> inference.TokenResponse:
        """Convert to ``inference.TokenResponse``.

        The field names are mapped as ``logprob`` -> ``log_prob`` and
        ``id`` -> ``token_id``. ``text`` is passed through unchanged.
        """
        payload = {
            "text": self.text,
            "log_prob": self.logprob,
            "token_id": self.id,
        }
        return inference.TokenResponse(**payload)
|
|
|
|
|
|
# Allowed values for StreamDetails.finish_reason.
_FinishReason = Literal["length", "eos_token", "stop_sequence"]


class StreamDetails(pydantic.BaseModel):
    """Summary metadata that may accompany a streamed generation response."""

    # Count of tokens generated.
    generated_tokens: int
    # Seed associated with the generation; may be absent.
    seed: int | None
    # Reason generation finished; restricted to the literals in _FinishReason.
    finish_reason: _FinishReason
|
|
|
|
|
|
class GenerateStreamResponse(pydantic.BaseModel):
    """One event of a streamed generation: a token plus optional extra data."""

    # The token carried by this event.
    token: Token
    # Full generated text, if supplied (presumably only on the last event — confirm).
    generated_text: str | None
    # Generation summary, if supplied (presumably only on the last event — confirm).
    details: StreamDetails | None
|