Files
Open-Assistant/inference/worker/utils.py
T
2023-02-09 23:46:44 +01:00

41 lines
1.3 KiB
Python

import collections
from typing import Literal
import interface
class TokenBuffer:
    """Hold back streamed tokens so a trailing stop sequence can be removed.

    Tokens are released from the front of the buffer only once the text that
    is buffered *after* them is at least as long as the longest stop
    sequence. A stop sequence made of the last few tokens therefore never
    reaches the caller before :meth:`finish` gets the chance to drop it.

    Both :meth:`add` and :meth:`finish` are generators: they have no effect
    until the iterator they return is consumed.
    """

    def __init__(self, stop_sequences: list[str]) -> None:
        self.stop_sequences = stop_sequences
        # Characters that must remain buffered behind a token before it may
        # be released; 0 (no stop sequences) means tokens pass straight through.
        self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0)
        self.tokens = collections.deque()
        # Text length of each buffered token, kept in lockstep with self.tokens.
        self.token_lens = collections.deque()
        # Running sum of self.token_lens.
        self.total_len = 0

    def add(self, token: "interface.Token"):
        """Buffer ``token`` and yield, in arrival order, every token now safe to emit.

        A token at the front of the buffer is released while the text buffered
        behind it is at least ``longest_stop_len`` characters long.

        Note: this is a generator; ``token`` is only buffered once the returned
        iterator is iterated.
        """
        # Measure the token's text (the same ``.text`` that finish()
        # concatenates); plain len(token) would require Token to define __len__.
        token_len = len(token.text)
        self.tokens.append(token)
        self.token_lens.append(token_len)
        self.total_len += token_len
        while self.tokens and self.total_len - self.token_lens[0] >= self.longest_stop_len:
            self.total_len -= self.token_lens.popleft()
            yield self.tokens.popleft()

    def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]):
        """Yield the tokens still held in the buffer once generation has ended.

        Args:
            reason: Why generation stopped. For ``"stop_sequence"``, the
                shortest run of trailing tokens whose concatenated text exactly
                equals one of ``stop_sequences`` is dropped before the rest is
                yielded. If no trailing run matches, nothing is dropped. For
                any other reason the buffer is yielded unchanged.

        The buffer is not cleared; yielded tokens remain in ``self.tokens``.
        """
        if reason == "stop_sequence":
            end_sequence = ""
            removed = []  # (token, length) pairs popped from the back, newest first
            while self.tokens:
                token = self.tokens.pop()
                removed.append((token, self.token_lens.pop()))
                end_sequence = token.text + end_sequence
                if end_sequence in self.stop_sequences:
                    break
            else:
                # No trailing run matched a stop sequence: put the tokens back
                # instead of silently discarding the whole buffer.
                for token, token_len in reversed(removed):
                    self.tokens.append(token)
                    self.token_lens.append(token_len)
            self.total_len = sum(self.token_lens)
        yield from self.tokens