Files
Open-Assistant/backend/oasst_backend/utils/tree_export.py
T
Andreas Köpf 1153734edb Allow to export filtered by tree state and/or only prompts (#1314)
* allow to export filtered by tree state and/or only prompts

* formatting

* add option to filter by lang code

* remove old export code

* export as list in only-deleted case
2023-02-07 22:17:36 +01:00

120 lines
3.8 KiB
Python

from __future__ import annotations

import contextlib
import gzip
import json
import sys
from collections import defaultdict
from typing import Iterable, Iterator, Optional, TextIO
from uuid import UUID

from fastapi.encoders import jsonable_encoder
from oasst_backend.models import Message
from oasst_backend.models.message_tree_state import State as TreeState
from pydantic import BaseModel
class ExportMessageNode(BaseModel):
    """A single message in the export format, optionally carrying its replies.

    Every nullable field has an explicit ``None`` default so that building a
    node does not depend on pydantic implicitly treating optional annotations
    as "not required". ``prep_message_export`` never sets ``replies``; without
    the explicit default that call only validates on pydantic versions that
    supply the ``None`` implicitly.

    Field order is significant: it determines the key order of the encoded
    output, so new fields should be appended rather than inserted.
    """

    message_id: str
    parent_id: str | None = None  # None marks a message without a parent
    text: str
    role: str
    lang: str | None = None
    review_count: int | None = None
    rank: int | None = None
    synthetic: bool | None = None
    model_name: str | None = None
    emojis: dict[str, int] | None = None
    replies: list[ExportMessageNode] | None = None  # left unset by prep_message_export

    @staticmethod
    def prep_message_export(message: Message) -> ExportMessageNode:
        """Convert a ``Message`` into an ``ExportMessageNode`` without replies.

        Args:
            message: The message to convert. Its ``id`` and ``parent_id`` are
                stringified, and the text is read from
                ``message.payload.payload.text``.

        Returns:
            A node whose ``replies`` is ``None``.
        """
        return ExportMessageNode(
            message_id=str(message.id),
            # Keep a missing parent as None instead of the string "None".
            parent_id=str(message.parent_id) if message.parent_id else None,
            text=str(message.payload.payload.text),
            role=message.role,
            lang=message.lang,
            review_count=message.review_count,
            rank=message.rank,
            synthetic=message.synthetic,
            model_name=message.model_name,
            emojis=message.emojis,
        )
class ExportMessageTree(BaseModel):
    """Export container for one message tree.

    Holds the tree id, an optional tree state, and an optional ``prompt``
    node from which the nested replies hang.
    """

    message_tree_id: str
    tree_state: str | None
    prompt: ExportMessageNode | None
def build_export_tree(
    message_tree_id: UUID, message_tree_state: TreeState, messages: list[Message]
) -> ExportMessageTree:
    """Assemble the nested export structure for one message tree.

    Each message is converted to an ``ExportMessageNode`` and grouped by its
    ``parent_id``. Starting from the first message without a parent, every
    reachable node gets its ``replies`` list, sorted by ascending ``rank``
    with unranked replies last. Leaf nodes get an empty list.

    Args:
        message_tree_id: Id of the tree; exported as a string.
        message_tree_state: State stored in the export's ``tree_state``.
        messages: Messages belonging to the tree, in any order.

    Returns:
        The export tree. ``prompt`` is ``None`` when ``messages`` contains no
        message without a parent (for example an empty list); previously this
        case raised ``IndexError``.
    """
    children_of: defaultdict[str | None, list[ExportMessageNode]] = defaultdict(list)
    for message in messages:
        node = ExportMessageNode.prep_message_export(message)
        children_of[node.parent_id].append(node)

    def rank_key(node: ExportMessageNode) -> float:
        # Unranked replies sort after all ranked ones.
        return node.rank if node.rank is not None else float("inf")

    # .get() avoids creating an empty bucket and lets us detect a missing root.
    roots = children_of.get(None)
    prompt = roots[0] if roots else None

    # Explicit stack instead of recursion so that very deep reply chains
    # cannot exceed the interpreter's recursion limit.
    pending = [prompt] if prompt is not None else []
    while pending:
        node = pending.pop()
        node.replies = children_of[node.message_id]
        node.replies.sort(key=rank_key)
        pending.extend(node.replies)

    return ExportMessageTree(message_tree_id=str(message_tree_id), tree_state=message_tree_state, prompt=prompt)
# see https://stackoverflow.com/questions/17602878/how-to-handle-both-with-open-and-sys-stdout-nicely
@contextlib.contextmanager
def smart_open(filename: str | None = None) -> Iterator[TextIO]:
    """Context manager yielding a writable text stream for a file or stdout.

    Args:
        filename: Path of the file to write (opened in text mode, UTF-8).
            When empty, ``None`` or ``"-"``, ``sys.stdout`` is yielded.

    Yields:
        The opened file, or ``sys.stdout``.

    A file opened here is closed on exit; ``sys.stdout`` is never closed.
    """
    # Record ownership at open time. Comparing against sys.stdout on exit is
    # unreliable: if stdout is rebound inside the with-body, the original
    # stdout stream would be mistaken for a file and closed.
    owns_handle = bool(filename) and filename != "-"
    if owns_handle:
        fh = open(filename, "wt", encoding="UTF-8")
    else:
        fh = sys.stdout

    try:
        yield fh
    finally:
        if owns_handle:
            fh.close()
def write_trees_to_file(filename: str | None, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
    """Write message trees as JSON lines, one encoded tree per line.

    Args:
        filename: Destination path. Required when ``use_compression`` is
            true; otherwise it is passed to ``smart_open``.
        trees: Trees to encode with ``jsonable_encoder`` (``None`` values are
            excluded) and write in order.
        use_compression: When true, write gzip-compressed UTF-8 text.

    Raises:
        RuntimeError: If compression is requested without a file name.
    """
    if not use_compression:
        sink = smart_open(filename)
    elif filename:
        sink = gzip.open(filename, "wt", encoding="UTF-8")
    else:
        raise RuntimeError("File name must be specified when using compression.")

    with sink as out:
        for tree in trees:
            encoded = jsonable_encoder(tree, exclude_none=True)
            out.write(json.dumps(encoded))
            out.write("\n")
def write_messages_to_file(filename: str | None, messages: Iterable[Message], use_compression: bool = True) -> None:
    """Write individual messages as JSON lines, one message per line.

    Each message is converted with ``ExportMessageNode.prep_message_export``
    and encoded with ``jsonable_encoder``, excluding ``None`` values.

    Args:
        filename: Destination path. Required when ``use_compression`` is
            true; otherwise it is passed to ``smart_open``.
        messages: Messages to export; consumed once, in order.
        use_compression: When true, write gzip-compressed UTF-8 text.

    Raises:
        RuntimeError: If compression is requested without a file name.
    """
    # Validate before anything is opened.
    if use_compression and not filename:
        raise RuntimeError("File name must be specified when using compression.")

    sink = gzip.open(filename, "wt", encoding="UTF-8") if use_compression else smart_open(filename)

    with sink as out:
        for message in messages:
            node = ExportMessageNode.prep_message_export(message)
            line = json.dumps(jsonable_encoder(node, exclude_none=True))
            out.write(line)
            out.write("\n")