diff --git a/backend/export.py b/backend/export.py index c0164172..2c573551 100644 --- a/backend/export.py +++ b/backend/export.py @@ -8,7 +8,7 @@ from oasst_backend.database import engine from oasst_backend.models import Message, MessageTreeState from oasst_backend.models.message_tree_state import State as TreeState from oasst_backend.utils import tree_export -from sqlmodel import Session +from sqlmodel import Session, not_ def fetch_tree_ids( @@ -38,6 +38,7 @@ def fetch_tree_messages( deleted: bool = None, prompts_only: bool = False, lang: Optional[str] = None, + review_result: Optional[bool] = None, ) -> List[Message]: qry = db.query(Message) @@ -51,6 +52,10 @@ def fetch_tree_messages( qry = qry.filter(Message.parent_id.is_(None)) if lang: qry = qry.filter(Message.lang == lang) + if review_result is False: + qry = qry.filter(not_(Message.review_result), Message.review_count > 2) + elif review_result is True: + qry = qry.filter(Message.review_result) return qry.all() @@ -64,16 +69,18 @@ def export_trees( prompts_only: bool = False, state_filter: Optional[TreeState] = None, lang: Optional[str] = None, + review_result: Optional[bool] = None, ) -> None: trees_to_export: List[tree_export.ExportMessageTree] = [] - if user_id: + if user_id or review_result is False: messages = fetch_tree_messages( db, user_id=user_id, deleted=deleted, prompts_only=prompts_only, lang=lang, + review_result=review_result, ) tree_export.write_messages_to_file(export_file, messages, use_compression) else: @@ -86,6 +93,7 @@ def export_trees( deleted=deleted, prompts_only=prompts_only, lang=None, + review_result=review_result, ) for (tree_id, _) in message_tree_ids ] @@ -135,6 +143,16 @@ def parse_args(): action="store_true", help="Export only deleted messages (implies --include-deleted)", ) + parser.add_argument( + "--include-spam", + action="store_true", + help="Export only messages with negative review result.", + ) + parser.add_argument( + "--spam-only", + action="store_true", + help="Export only messages with negative review result (implies --include-spam).", + ) parser.add_argument( "--user", type=str, @@ -176,6 +194,12 @@ def main(): if args.deleted_only: deleted = True + review_result: Optional[bool] = True + if args.include_spam: + review_result = None + if args.spam_only: + review_result = False + with Session(engine) as db: export_trees( db, @@ -186,6 +210,7 @@ def main(): prompts_only=args.prompts_only, state_filter=state_filter, lang=args.lang, + review_result=review_result, ) diff --git a/backend/oasst_backend/utils/tree_export.py b/backend/oasst_backend/utils/tree_export.py index 409d8a22..4aaba92b 100644 --- a/backend/oasst_backend/utils/tree_export.py +++ b/backend/oasst_backend/utils/tree_export.py @@ -21,6 +21,7 @@ class ExportMessageNode(BaseModel): role: str lang: str | None review_count: int | None + review_result: bool | None rank: int | None synthetic: bool | None model_name: str | None @@ -36,6 +37,7 @@ class ExportMessageNode(BaseModel): role=message.role, lang=message.lang, review_count=message.review_count, + review_result=message.review_result if message.review_result or message.review_count > 2 else None, synthetic=message.synthetic, model_name=message.model_name, emojis=message.emojis,