mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
cf4c8e70ce
* Add export script
144 lines
4.1 KiB
Python
144 lines
4.1 KiB
Python
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
from loguru import logger
|
|
from oasst_backend.database import engine
|
|
from oasst_backend.models import Message, MessageTreeState, message_tree_state
|
|
from oasst_backend.utils import tree_export
|
|
from sqlmodel import Session, not_
|
|
|
|
|
|
def fetch_tree_ids(db: Session, ready_only: bool = False):
    """Return the ``message_tree_id`` of every ``MessageTreeState`` row.

    Args:
        db: Open database session used to run the query.
        ready_only: When true, only rows whose ``state`` equals
            ``message_tree_state.State.READY_FOR_EXPORT`` are returned.

    Returns:
        A list with one ``message_tree_id`` per matching row.
    """
    state_query = db.query(MessageTreeState)

    if ready_only:
        export_ready = message_tree_state.State.READY_FOR_EXPORT
        state_query = state_query.filter(MessageTreeState.state == export_ready)

    return [row.message_tree_id for row in state_query.all()]
|
|
|
|
|
|
def fetch_message_ids(
    db: Session,
    message_tree_id: Optional[UUID] = None,
    user_id: Optional[UUID] = None,
    include_deleted: bool = False,
    deleted_only: bool = False,
) -> List[Message]:
    """Query ``Message`` rows, narrowed by the optional filters.

    Despite the name, this returns full ``Message`` objects, not just IDs.

    Args:
        db: Open database session used to run the query.
        message_tree_id: If truthy, keep only messages of this tree.
        user_id: If truthy, keep only messages of this user.
        include_deleted: If false, messages flagged ``deleted`` are excluded.
        deleted_only: If true, keep only messages flagged ``deleted``.
            NOTE(review): with ``include_deleted=False`` both the "not deleted"
            and the "deleted" filters are applied, which presumably matches
            nothing; callers are expected to set ``include_deleted`` as well.

    Returns:
        All matching ``Message`` rows.
    """
    # Collect the active criteria first, then apply them in a fixed order.
    criteria = []
    if message_tree_id:
        criteria.append(Message.message_tree_id == message_tree_id)
    if user_id:
        criteria.append(Message.user_id == user_id)
    if not include_deleted:
        criteria.append(not_(Message.deleted))
    if deleted_only:
        criteria.append(Message.deleted)

    message_query = db.query(Message)
    for criterion in criteria:
        message_query = message_query.filter(criterion)

    return message_query.all()
|
|
|
|
|
|
def export_trees(
    db: Session,
    export_file: Optional[Path] = None,
    use_compression: bool = False,
    ready_only: bool = False,
    include_deleted: bool = False,
    deleted_only: bool = False,
    user_id: Optional[UUID] = None,
):
    """Build export trees and write them to a file or to STDOUT.

    Args:
        db: Open database session used for all queries.
        export_file: Destination path. When ``None`` the trees are dumped as
            indented JSON to STDOUT instead.
        use_compression: Forwarded to ``tree_export.write_trees_to_file``.
        ready_only: Only export trees in the ``READY_FOR_EXPORT`` state.
            Ignored when ``user_id`` is given (tree IDs then come from the
            user's messages rather than from the tree-state table).
        include_deleted: Include messages flagged as deleted.
        deleted_only: Export only messages flagged as deleted.
        user_id: When given, only trees containing at least one matching
            message of this user are exported.
    """
    if user_id:
        messages = fetch_message_ids(db, user_id=user_id, include_deleted=include_deleted, deleted_only=deleted_only)
        # A user may have several messages in the same tree. De-duplicate the
        # IDs (keeping first-seen order) so every tree is fetched and exported
        # exactly once instead of once per message.
        message_tree_ids = list(dict.fromkeys(msg.message_tree_id for msg in messages))
    else:
        message_tree_ids = fetch_tree_ids(db, ready_only)

    trees_to_export: List[tree_export.ExportMessageTree] = [
        tree_export.build_export_tree(
            tree_id,
            fetch_message_ids(db, message_tree_id=tree_id, include_deleted=include_deleted, deleted_only=deleted_only),
        )
        for tree_id in message_tree_ids
    ]

    if export_file:
        tree_export.write_trees_to_file(export_file, trees_to_export, use_compression)
    else:
        sys.stdout.write(json.dumps(jsonable_encoder(trees_to_export), indent=2))
|
|
|
|
|
|
def validate_args(args):
    """Validate the parsed CLI arguments and derive dependent settings in place.

    Side effects on ``args``:
        * ``include_deleted`` is forced to ``True`` when ``deleted_only`` is set.
        * ``use_compression`` is set to ``True`` when an export file name
          containing ``".gz"`` was given, otherwise ``False``.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: If ``--ready-only`` is combined with a user filter.
    """
    if args.deleted_only:
        args.include_deleted = True

    args.use_compression = args.export_file is not None and ".gz" in args.export_file

    # The CLI option is ``--user``, so argparse stores it as ``args.user``.
    # Reading ``args.user_id`` raised AttributeError whenever --ready-only was
    # passed. ``user_id`` is still honoured as a fallback for hand-built
    # namespaces that used the old attribute name.
    requested_user = getattr(args, "user", None)
    if requested_user is None:
        requested_user = getattr(args, "user_id", None)

    if args.ready_only and requested_user is not None:
        raise ValueError("Cannot use --ready-only when specifying a user ID")

    if args.export_file is None:
        logger.warning("No export file provided, output will be sent to STDOUT")
|
|
|
|
|
|
def parse_args():
    """Parse the export script's command-line options.

    Returns:
        The ``argparse.Namespace`` with the attributes ``export_file``,
        ``ready_only``, ``include_deleted``, ``deleted_only`` and ``user``.
    """
    # Declared as a table so flag, kind and help text sit side by side;
    # the order here is the order shown in ``--help``.
    option_specs = (
        (
            "--export-file",
            {
                "type": str,
                "help": "Name of file to export trees to. If not provided, output will be sent to STDOUT",
            },
        ),
        (
            "--ready-only",
            {
                "action": "store_true",
                "help": "Only export trees which are ready for use",
            },
        ),
        (
            "--include-deleted",
            {
                "action": "store_true",
                "help": "Include deleted messages in export",
            },
        ),
        (
            "--deleted-only",
            {
                "action": "store_true",
                "help": "Export only deleted messages (implies --include-deleted)",
            },
        ),
        (
            "--user",
            {
                "type": str,
                "help": "Only export trees involving the user with the specified ID. Incompatible with --ready-only.",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
|
|
def main():
    """Script entry point: parse and validate the CLI arguments, then export.

    Opens a database session on the shared ``engine`` and hands the converted
    arguments to ``export_trees``.
    """
    args = parse_args()
    validate_args(args)

    with Session(engine) as db:
        # Convert the raw CLI strings into the types export_trees expects.
        export_path = None if args.export_file is None else Path(args.export_file)
        user_filter = None if args.user is None else UUID(args.user)

        export_trees(
            db,
            export_file=export_path,
            use_compression=args.use_compression,
            ready_only=args.ready_only,
            include_deleted=args.include_deleted,
            deleted_only=args.deleted_only,
            user_id=user_filter,
        )


# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|