diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index a5eb9185..cea7f257 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -54,7 +54,7 @@ def main( "top_p": work_request.top_p, "temperature": work_request.temperature, "seed": work_request.seed, - # "stop": ["User:", "Assistant:"], # TODO: this doesn't work... why? + # "stop": ["\nUser:", "\nAssistant:"], # TODO: make this a bit more workable because it's mutliple tokens }, }, stream=True, @@ -64,6 +64,7 @@ def main( response.raise_for_status() except requests.HTTPError: logger.exception("Failed to get response from inference server") + logger.error(f"Response: {response.text}") return client = sseclient.SSEClient(response) diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 764d8acf..a638e55c 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -13,7 +13,7 @@ class WorkRequest(pydantic.BaseModel): conversation: protocol.Conversation = pydantic.Field(..., repr=False) model_name: str = "distilgpt2" max_new_tokens: int = 100 - seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1)) + seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**31 - 1)) do_sample: bool = True top_k: int = 50 top_p: float = 0.9