mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
44 lines
1.0 KiB
Python
44 lines
1.0 KiB
Python
import random
|
|
from typing import Literal
|
|
|
|
import pydantic
|
|
|
|
from . import protocol
|
|
|
|
|
|
class WorkerConfig(pydantic.BaseModel):
    """Pydantic model describing a worker by the model it serves.

    NOTE(review): presumably exchanged between worker and server to decide
    which requests a worker can take -- confirm against callers.
    """

    # Name of the model; defaults to "distilgpt2".
    model_name: str = "distilgpt2"

    @property
    def compat_hash(self) -> str:
        """Return the compatibility key, currently just the model name as text."""
        return str(self.model_name)
|
|
|
|
|
|
class WorkRequest(pydantic.BaseModel):
    """Pydantic model for a single unit of work built around a conversation.

    Only ``conversation`` is required; every other field has a default.
    NOTE(review): the remaining field names match common text-generation
    sampling options and are presumably forwarded to the generation backend --
    confirm against the consumer of this model.
    """

    # Required field; left out of the model's repr (``repr=False``).
    conversation: protocol.Conversation = pydantic.Field(
        ...,
        repr=False,
    )
    model_name: str = "distilgpt2"
    max_new_tokens: int = 100
    # A fresh value is drawn for each instance that does not supply a seed.
    # Bounds are inclusive: 0 .. 2**64 - 2 (i.e. 0xFFFF_FFFF_FFFF_FFFF - 1).
    seed: int = pydantic.Field(
        default_factory=lambda: random.randint(0, 2**64 - 2),
    )
    do_sample: bool = True
    top_k: int = 50
    top_p: float = 0.9
    temperature: float = 1.0
    # None means no value was supplied for the penalty.
    repetition_penalty: float | None = None
|
|
|
|
|
|
class TokenResponse(pydantic.BaseModel):
    """Pydantic model carrying one token: its text, log-probability and id.

    All three fields are required (none has a default).
    """

    text: str  # textual form of the token
    log_prob: float  # log-probability value attached to the token
    token_id: int  # integer id of the token
|
|
|
|
|
|
class GeneratedTextResponse(pydantic.BaseModel):
    """Pydantic model carrying generated text and the reason generation ended.

    Both fields are required (neither has a default).
    """

    text: str
    # Restricted to exactly these three string values.
    finish_reason: Literal["length", "eos_token", "stop_sequence"]
|
|
|
|
|
|
class WorkResponsePacket(pydantic.BaseModel):
    """Pydantic envelope for one piece of a work response.

    Every field is optional: ``token`` and ``generated_text`` default to
    ``None`` and ``is_end`` defaults to ``False``.
    NOTE(review): presumably a stream of packets carries ``token`` values and
    ends with one that has ``is_end`` set -- confirm against the producer.
    """

    # Optional per-token payload.
    token: TokenResponse | None = None
    # Optional generated-text payload.
    generated_text: GeneratedTextResponse | None = None
    # End marker flag; False unless explicitly set.
    is_end: bool = False
|