mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import collections
|
|
from typing import Literal
|
|
|
|
import interface
|
|
|
|
|
|
class TokenBuffer:
    """Hold back the tail of a token stream so a stop sequence can be removed.

    Tokens passed to ``add`` are released from the front of the buffer only
    once at least ``longest_stop_len`` characters of newer tokens remain
    buffered behind them. When generation ends, ``finish`` yields whatever is
    still buffered, dropping the trailing stop sequence when the stream ended
    because of one.

    Tokens are expected to support ``len()`` and to expose a ``text``
    attribute; both are used by this class.
    """

    def __init__(self, stop_sequences: list[str]) -> None:
        """Create an empty buffer.

        Args:
            stop_sequences: Strings that may terminate generation. The longest
                one determines how many characters are held back.
        """
        self.stop_sequences = stop_sequences
        # With no stop sequences nothing needs to be held back (length 0).
        self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0)
        # Buffered tokens, oldest first, and their lengths kept in lockstep.
        self.tokens = collections.deque()
        self.token_lens = collections.deque()
        # Sum of the lengths of all currently buffered tokens.
        self.total_len = 0

    def add(self, token: interface.Token):
        """Buffer ``token`` and yield any tokens that are now safe to release.

        This is a generator: the token is only appended once iteration
        starts, so the result must be consumed.

        Args:
            token: The newly generated token.

        Yields:
            Buffered tokens, oldest first, for which the tokens remaining
            behind them total at least ``longest_stop_len`` characters.
        """
        token_len = len(token)
        self.tokens.append(token)
        self.token_lens.append(token_len)
        self.total_len += token_len

        while self.tokens:
            head_len = self.token_lens[0]
            # Keep the head if releasing it would leave fewer buffered
            # characters than the longest stop sequence needs.
            if self.total_len - head_len < self.longest_stop_len:
                break
            self.token_lens.popleft()
            self.total_len -= head_len
            yield self.tokens.popleft()

    def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]):
        """Yield the tokens still buffered once generation has ended.

        For ``reason == "stop_sequence"`` tokens are removed from the tail
        until their concatenated text equals one of ``stop_sequences``; the
        removed tokens are not yielded. If no such match exists on whole-token
        boundaries, the removed tokens are put back so that no buffered output
        is lost. For any other reason the buffer is yielded unchanged.

        Only ``self.tokens`` is modified here; ``token_lens`` and
        ``total_len`` are left as they were.

        Args:
            reason: Why generation stopped.

        Yields:
            The remaining buffered tokens, oldest first.
        """
        if reason == "stop_sequence":
            end_sequence = ""
            removed = []
            while self.tokens:
                tail = self.tokens.pop()
                removed.append(tail)
                end_sequence = tail.text + end_sequence
                if end_sequence in self.stop_sequences:
                    break
            else:
                # The buffer was exhausted without matching a stop sequence:
                # restore the tokens in their original order instead of
                # silently discarding them.
                self.tokens.extend(reversed(removed))
        yield from self.tokens