mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
import collections
|
|
import random
|
|
import time
|
|
from typing import Literal
|
|
|
|
import interface
|
|
import requests
|
|
from loguru import logger
|
|
|
|
|
|
class TokenBuffer:
    """Delay the tail of a token stream so a trailing stop sequence can be dropped.

    Tokens passed to ``add`` are held back until the tokens queued behind them
    have a combined ``len()`` of at least ``longest_stop_len``; only then are
    they released. ``finish`` releases whatever is still buffered, first
    removing the trailing tokens whose concatenated ``text`` equals one of the
    stop sequences when the stream ended because of a stop sequence.

    Tokens must support ``len(token)`` and expose a ``text`` attribute.
    """

    def __init__(self, stop_sequences: list[str]) -> None:
        """Create an empty buffer for the given stop sequences."""
        self.stop_sequences = stop_sequences
        # Length of the longest stop sequence; 0 when there are none, in which
        # case ``add`` releases every token immediately.
        self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0)
        self.tokens = collections.deque()
        # len() of each buffered token, kept parallel to ``tokens``.
        self.token_lens = collections.deque()
        # Sum of ``token_lens``.
        self.total_len = 0

    def add(self, token: interface.Token):
        """Buffer ``token`` and yield the tokens that can now be released.

        This is a generator: the token is only appended once the returned
        iterator is consumed.
        """
        token_len = len(token)
        self.tokens.append(token)
        self.token_lens.append(token_len)
        self.total_len += token_len
        # Release the oldest token while what stays behind it is still at
        # least as long as the longest stop sequence.
        while self.tokens and self.total_len - self.token_lens[0] >= self.longest_stop_len:
            head_len = self.token_lens.popleft()
            self.total_len -= head_len
            yield self.tokens.popleft()

    def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]):
        """Yield the tokens still buffered at the end of the stream.

        When ``reason`` is ``"stop_sequence"``, tokens are removed from the end
        of the buffer until their concatenated ``text`` equals one of the stop
        sequences, and only the tokens before them are yielded. If no such
        suffix exists, the buffer is left untouched and yielded in full rather
        than being discarded.

        The length bookkeeping used by ``add`` is not updated here, so ``add``
        should not be called again afterwards.
        """
        if reason == "stop_sequence":
            end_sequence = ""
            removed = []
            while self.tokens:
                tail = self.tokens.pop()
                removed.append(tail)
                end_sequence = tail.text + end_sequence
                if end_sequence in self.stop_sequences:
                    break
            else:
                # No suffix matched a stop sequence: put the tokens back in
                # their original order instead of losing them.
                self.tokens.extend(reversed(removed))
        yield from self.tokens
|
|
|
|
|
|
def wait_for_inference_server(inference_server_url: str, timeout: int = 600, request_timeout: float = 10):
    """Block until ``{inference_server_url}/health`` responds with a success status.

    The health endpoint is polled in a loop. After each failed attempt the
    function sleeps for a random duration between 0 and 10 seconds before
    retrying, and logs a warning.

    Args:
        inference_server_url: Base URL of the inference server; ``/health`` is
            appended to it.
        timeout: Seconds after which a failed attempt is no longer retried.
        request_timeout: Seconds a single health request may take. Without
            it, a server that accepts the connection but never answers would
            block forever and ``timeout`` would never be checked.

    Raises:
        requests.HTTPError, requests.ConnectionError, requests.Timeout: The
            error of the first attempt that fails after ``timeout`` seconds
            have elapsed.
    """
    health_url = f"{inference_server_url}/health"
    # Monotonic clock: the deadline must not move when the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while True:
        try:
            response = requests.get(health_url, timeout=request_timeout)
            response.raise_for_status()
        except (requests.HTTPError, requests.ConnectionError, requests.Timeout):
            if time.monotonic() > deadline:
                raise
            sleep_duration = random.uniform(0, 10)
            logger.warning(f"Inference server not ready. Retrying in {sleep_duration:.2f} seconds")
            time.sleep(sleep_duration)
        else:
            logger.info("Inference server is ready")
            break
|