import collections import random import time from typing import Literal import interface import requests from loguru import logger class TokenBuffer: def __init__(self, stop_sequences: list[str]) -> None: self.stop_sequences = stop_sequences self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0) self.tokens = collections.deque() self.token_lens = collections.deque() self.total_len = 0 def add(self, token: interface.Token): self.tokens.append(token) self.token_lens.append(len(token)) self.total_len += len(token) while True: if not self.tokens: break head_len = self.token_lens[0] if self.total_len - head_len >= self.longest_stop_len: token = self.tokens.popleft() self.token_lens.popleft() self.total_len -= head_len yield token else: break def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]): if reason == "stop_sequence": end_sequence = "" while self.tokens: end_sequence = self.tokens.pop().text + end_sequence if end_sequence in self.stop_sequences: break yield from self.tokens else: yield from self.tokens def wait_for_inference_server(inference_server_url: str, timeout: int = 600): health_url = f"{inference_server_url}/health" time_limit = time.time() + timeout while True: try: response = requests.get(health_url) response.raise_for_status() except (requests.HTTPError, requests.ConnectionError): if time.time() > time_limit: raise sleep_duration = random.uniform(0, 10) logger.warning(f"Inference server not ready. Retrying in {sleep_duration:.2f} seconds") time.sleep(sleep_duration) else: logger.info("Inference server is ready") break