Files
Open-Assistant/inference/worker/utils.py
T
Yannic Kilcher 90c3d5640e Added database to inference server (#1446)
* added db for inference

* fixed dockerfiles for inference
2023-02-10 22:51:35 +01:00

63 lines
2.0 KiB
Python

import collections
import random
import time
from typing import Literal
import interface
import requests
from loguru import logger
class TokenBuffer:
    """Delay the tail of a token stream so a trailing stop sequence can be cut.

    Tokens passed to :meth:`add` are held back until enough newer text has
    been buffered behind them to cover the longest configured stop sequence.
    :meth:`finish` then flushes whatever is still buffered, dropping the
    trailing tokens that spell a stop sequence when generation ended on one.

    Tokens must support ``len()`` and expose a ``.text`` string attribute.
    NOTE(review): ``len(token)`` is compared against ``len(stop)`` of the stop
    strings, so it is presumably the character length of ``token.text`` --
    confirm against ``interface.Token``.
    """

    def __init__(self, stop_sequences: list[str]) -> None:
        """Create an empty buffer for the given stop sequences.

        Args:
            stop_sequences: Strings that mark the end of generation. May be
                empty, in which case ``add`` releases tokens immediately.
        """
        self.stop_sequences = stop_sequences
        # 0 when no stop sequences are configured (nothing needs holding back).
        self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0)
        self.tokens = collections.deque()
        # Parallel to ``self.tokens``: the length recorded for each token.
        self.token_lens = collections.deque()
        self.total_len = 0

    def add(self, token: "interface.Token"):
        """Buffer ``token`` and yield the oldest tokens that are safe to emit.

        A head token is released once the length of everything buffered after
        it is at least ``longest_stop_len``.

        Yields:
            Previously buffered tokens, oldest first.
        """
        token_len = len(token)
        self.tokens.append(token)
        self.token_lens.append(token_len)
        self.total_len += token_len

        while self.tokens:
            head_len = self.token_lens[0]
            if self.total_len - head_len < self.longest_stop_len:
                break
            # Update the bookkeeping before handing the token to the consumer.
            self.token_lens.popleft()
            self.total_len -= head_len
            yield self.tokens.popleft()

    def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]):
        """Flush the remaining buffered tokens once generation has ended.

        For ``reason == "stop_sequence"`` the trailing tokens whose
        concatenated text equals one of the stop sequences are removed first.
        If no run of trailing tokens matches a stop sequence exactly (for
        example the stop sequence starts in the middle of a token), the
        buffer is left intact so no generated text is silently discarded.

        This method does not update ``token_lens`` / ``total_len``; it is
        meant to be the final call on the buffer.

        Yields:
            The remaining buffered tokens, oldest first.
        """
        if reason == "stop_sequence":
            end_sequence = ""
            popped = []
            while self.tokens:
                tail = self.tokens.pop()
                popped.append(tail)
                end_sequence = tail.text + end_sequence
                if end_sequence in self.stop_sequences:
                    break
            else:
                # No exact match: put the tokens back in their original order
                # instead of dropping the whole buffer.
                self.tokens.extend(reversed(popped))
        yield from self.tokens
def wait_for_inference_server(inference_server_url: str, timeout: int = 600, request_timeout: float = 10):
    """Block until the inference server's ``/health`` endpoint responds OK.

    The endpoint is polled with a random pause of up to 10 seconds between
    attempts. Once ``timeout`` seconds have elapsed, the next failed attempt
    re-raises its exception instead of retrying.

    Args:
        inference_server_url: Base URL of the server; ``/health`` is appended.
        timeout: Overall time budget in seconds for the server to come up.
        request_timeout: Per-request timeout in seconds passed to
            ``requests.get``. Without it a stalled connection could block
            indefinitely and the overall ``timeout`` would never be checked.

    Raises:
        requests.HTTPError: The last attempt returned an error status.
        requests.ConnectionError: The last attempt could not connect.
        requests.Timeout: The last attempt exceeded ``request_timeout``.
    """
    health_url = f"{inference_server_url}/health"
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        try:
            response = requests.get(health_url, timeout=request_timeout)
            response.raise_for_status()
        except (requests.HTTPError, requests.ConnectionError, requests.Timeout):
            if time.monotonic() > deadline:
                raise
            # Random pause so several workers do not poll in lockstep.
            sleep_duration = random.uniform(0, 10)
            logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds")
            time.sleep(sleep_duration)
        else:
            logger.info("Inference server is ready")
            break