import collections from typing import Literal import interface class TokenBuffer: def __init__(self, stop_sequences: list[str]) -> None: self.stop_sequences = stop_sequences self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0) self.tokens = collections.deque() self.token_lens = collections.deque() self.total_len = 0 def add(self, token: interface.Token): self.tokens.append(token) self.token_lens.append(len(token)) self.total_len += len(token) while True: if not self.tokens: break head_len = self.token_lens[0] if self.total_len - head_len >= self.longest_stop_len: token = self.tokens.popleft() self.token_lens.popleft() self.total_len -= head_len yield token else: break def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]): if reason == "stop_sequence": end_sequence = "" while self.tokens: end_sequence = self.tokens.pop().text + end_sequence if end_sequence in self.stop_sequences: break yield from self.tokens else: yield from self.tokens