mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
HF inference fixes
This commit is contained in:
@@ -54,7 +54,7 @@ def main(
|
||||
"top_p": work_request.top_p,
|
||||
"temperature": work_request.temperature,
|
||||
"seed": work_request.seed,
|
||||
# "stop": ["User:", "Assistant:"], # TODO: this doesn't work... why?
|
||||
# "stop": ["\nUser:", "\nAssistant:"], # TODO: make this a bit more workable because it's multiple tokens
|
||||
},
|
||||
},
|
||||
stream=True,
|
||||
@@ -64,6 +64,7 @@ def main(
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError:
|
||||
logger.exception("Failed to get response from inference server")
|
||||
logger.error(f"Response: {response.text}")
|
||||
return
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
|
||||
@@ -13,7 +13,7 @@ class WorkRequest(pydantic.BaseModel):
|
||||
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1))
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**31 - 1))
|
||||
do_sample: bool = True
|
||||
top_k: int = 50
|
||||
top_p: float = 0.9
|
||||
|
||||
Reference in New Issue
Block a user