mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
032a748ba5
* Added - Basic functions to export trees for users, export-ready trees and specific tree ids to files * Added print to logger by default for no file specified * linting to remove extra imports * Added cli for exporting trees which are ready to export Fixed some accidental removal Updated message lookup to use dict for better perf * removed unused imports * changed export flag for including deleted prompts back to include_deleted for better understandability * Use native collection types list, tuple, dict * pre-commit fix Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
import gzip
|
|
import json
|
|
from collections import defaultdict
|
|
from typing import Optional, TextIO
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
from oasst_backend.models import Message
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class ExportMessageNode(BaseModel):
    """A single message of an exported message tree.

    ``replies`` holds this node's child nodes. It is left as ``None`` by
    ``prep_message_export`` and filled in afterwards (``build_export_tree``
    assigns it a list, which is empty for leaf messages).
    """

    message_id: str
    # Optional fields get explicit ``None`` defaults: pydantic v1 already treats
    # un-defaulted Optional fields this way, but pydantic v2 would make them
    # required, and ``prep_message_export`` never passes ``replies``.
    parent_id: Optional[str] = None
    text: Optional[str] = None
    role: str
    review_count: Optional[int] = None
    rank: Optional[int] = None
    replies: Optional[list[ExportMessageNode]] = None

    @classmethod
    def prep_message_export(cls, message: Message) -> ExportMessageNode:
        """Create an export node (without replies) from a database ``Message``.

        Args:
            message: The message row to convert. Its ``id`` and ``parent_id`` are
                stringified; a falsy ``parent_id`` becomes ``None``.

        Returns:
            A new node whose ``replies`` is still ``None``.
        """
        return cls(
            message_id=str(message.id),
            parent_id=str(message.parent_id) if message.parent_id else None,
            # NOTE(review): str() turns a missing (None) text into the literal
            # string "None"; confirm payload text is always set for exported rows.
            text=str(message.payload.payload.text),
            role=message.role,
            review_count=message.review_count,
            rank=message.rank,
        )
|
|
|
|
|
|
class ExportMessageTree(BaseModel):
    """An exported message tree.

    ``replies`` holds the top-level nodes of the tree (messages without a
    parent); each node carries its own nested ``replies``.
    """

    message_tree_id: str
    # ``build_export_tree`` stores the *list* of parentless nodes here, so the
    # field is typed as a list; a single-node annotation made exported JSON
    # impossible to parse back into this model. Explicit ``None`` default keeps
    # construction without replies valid under pydantic v1 and v2.
    replies: Optional[list[ExportMessageNode]] = None
|
|
|
|
|
|
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
    """Assemble a flat list of messages into a nested export tree.

    Args:
        message_tree_id: Identifier of the tree; converted with ``str``.
        messages: The messages of the tree, linked to each other via
            ``parent_id``. Messages whose ``parent_id`` is ``None`` become the
            top-level ``replies`` of the returned tree.

    Returns:
        An ``ExportMessageTree`` whose ``replies`` list holds the parentless
        nodes. Every reachable node gets a ``replies`` list of its children
        (empty for leaves), and siblings keep the order they had in
        ``messages``. Messages whose parent is not among ``messages`` are not
        reachable from the top level and therefore do not appear in the result.
    """
    export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))
    nodes = [ExportMessageNode.prep_message_export(m) for m in messages]

    # Group nodes by parent id so each node's children are looked up in O(1).
    children_by_parent: defaultdict[Optional[str], list[ExportMessageNode]] = defaultdict(list)
    for node in nodes:
        children_by_parent[node.parent_id].append(node)

    # Attach children with an explicit stack instead of recursion, so very deep
    # conversation threads cannot exceed Python's recursion limit.
    export_tree.replies = children_by_parent[None]
    pending = list(export_tree.replies)
    while pending:
        node = pending.pop()
        node.replies = children_by_parent[node.message_id]
        pending.extend(node.replies)

    return export_tree
|
|
|
|
|
|
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
    """Write ``trees`` to ``file`` as JSON Lines, one tree per line.

    Args:
        file: Destination passed to ``gzip.open`` or the builtin ``open``.
        trees: Trees to serialize; each is converted with ``jsonable_encoder``
            and then dumped with ``json.dump``.
        use_compression: When true the text is gzip-compressed; otherwise a
            plain UTF-8 text file is written.
    """
    open_text = gzip.open if use_compression else open
    with open_text(file, "wt", encoding="UTF-8") as sink:
        for tree in trees:
            encoded = jsonable_encoder(tree)
            json.dump(encoded, sink)
            sink.write("\n")
|