From cf4c8e70ce37d423acce21f0cf8f6700cf5cb9cd Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Mon, 6 Feb 2023 21:29:30 +0000 Subject: [PATCH] 1184: Add dedicated export.py (#1270) * Add export script --- backend/export.py | 143 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 backend/export.py diff --git a/backend/export.py b/backend/export.py new file mode 100644 index 00000000..ef3fde49 --- /dev/null +++ b/backend/export.py @@ -0,0 +1,143 @@ +import argparse +import json +import sys +from pathlib import Path +from typing import List, Optional +from uuid import UUID + +from fastapi.encoders import jsonable_encoder +from loguru import logger +from oasst_backend.database import engine +from oasst_backend.models import Message, MessageTreeState, message_tree_state +from oasst_backend.utils import tree_export +from sqlmodel import Session, not_ + + +def fetch_tree_ids(db: Session, ready_only: bool = False): + tree_qry = db.query(MessageTreeState) + + if ready_only: + tree_qry = tree_qry.filter(MessageTreeState.state == message_tree_state.State.READY_FOR_EXPORT) + + message_trees = tree_qry.all() + + message_tree_ids = [tree.message_tree_id for tree in message_trees] + return message_tree_ids + + +def fetch_message_ids( + db: Session, + message_tree_id: Optional[UUID] = None, + user_id: Optional[UUID] = None, + include_deleted: bool = False, + deleted_only: bool = False, +) -> List[Message]: + qry = db.query(Message) + + if message_tree_id: + qry = qry.filter(Message.message_tree_id == message_tree_id) + if user_id: + qry = qry.filter(Message.user_id == user_id) + if not include_deleted: + qry = qry.filter(not_(Message.deleted)) + if deleted_only: + qry = qry.filter(Message.deleted) + + return qry.all() + + +def export_trees( + db: Session, + export_file: Optional[Path] = None, + use_compression: bool = False, + ready_only: bool = False, + include_deleted: bool = False, + deleted_only: bool = False, + user_id: Optional[UUID] = None, +): + trees_to_export: List[tree_export.ExportMessageTree] = [] + + if user_id: + messages = fetch_message_ids(db, user_id=user_id, include_deleted=include_deleted, deleted_only=deleted_only) + message_tree_ids = [msg.message_tree_id for msg in messages] + else: + message_tree_ids = fetch_tree_ids(db, ready_only) + + message_trees = [ + fetch_message_ids(db, message_tree_id=tree_id, include_deleted=include_deleted, deleted_only=deleted_only) + for tree_id in message_tree_ids + ] + + for message_tree_id, message_tree in zip(message_tree_ids, message_trees): + trees_to_export.append(tree_export.build_export_tree(message_tree_id, message_tree)) + + if export_file: + tree_export.write_trees_to_file(export_file, trees_to_export, use_compression) + else: + sys.stdout.write(json.dumps(jsonable_encoder(trees_to_export), indent=2)) + + +def validate_args(args): + if args.deleted_only: + args.include_deleted = True + + args.use_compression = args.export_file is not None and ".gz" in args.export_file + + if args.ready_only and args.user_id is not None: + raise ValueError("Cannot use --ready-only when specifying a user ID") + + if args.export_file is None: + logger.warning("No export file provided, output will be sent to STDOUT") + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--export-file", + type=str, + help="Name of file to export trees to. If not provided, output will be sent to STDOUT", + ) + parser.add_argument( + "--ready-only", + action="store_true", + help="Only export trees which are ready for use", + ) + parser.add_argument( + "--include-deleted", + action="store_true", + help="Include deleted messages in export", + ) + parser.add_argument( + "--deleted-only", + action="store_true", + help="Export only deleted messages (implies --include-deleted)", + ) + parser.add_argument( + "--user", + type=str, + help="Only export trees involving the user with the specified ID. Incompatible with --ready-only.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + validate_args(args) + + with Session(engine) as db: + export_trees( + db, + Path(args.export_file) if args.export_file is not None else None, + args.use_compression, + args.ready_only, + args.include_deleted, + args.deleted_only, + UUID(args.user) if args.user is not None else None, + ) + + +if __name__ == "__main__": + main()