mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add spam-only export option
This commit is contained in:
+27
-2
@@ -8,7 +8,7 @@ from oasst_backend.database import engine
|
||||
from oasst_backend.models import Message, MessageTreeState
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_backend.utils import tree_export
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, not_
|
||||
|
||||
|
||||
def fetch_tree_ids(
|
||||
@@ -38,6 +38,7 @@ def fetch_tree_messages(
|
||||
deleted: bool = None,
|
||||
prompts_only: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
review_result: Optional[bool] = None,
|
||||
) -> List[Message]:
|
||||
qry = db.query(Message)
|
||||
|
||||
@@ -51,6 +52,10 @@ def fetch_tree_messages(
|
||||
qry = qry.filter(Message.parent_id.is_(None))
|
||||
if lang:
|
||||
qry = qry.filter(Message.lang == lang)
|
||||
if review_result is False:
|
||||
qry = qry.filter(not_(Message.review_result), Message.review_count > 2)
|
||||
elif review_result is True:
|
||||
qry = qry.filter(Message.review_result)
|
||||
|
||||
return qry.all()
|
||||
|
||||
@@ -64,16 +69,18 @@ def export_trees(
|
||||
prompts_only: bool = False,
|
||||
state_filter: Optional[TreeState] = None,
|
||||
lang: Optional[str] = None,
|
||||
review_result: Optional[bool] = None,
|
||||
) -> None:
|
||||
trees_to_export: List[tree_export.ExportMessageTree] = []
|
||||
|
||||
if user_id:
|
||||
if user_id or review_result is False:
|
||||
messages = fetch_tree_messages(
|
||||
db,
|
||||
user_id=user_id,
|
||||
deleted=deleted,
|
||||
prompts_only=prompts_only,
|
||||
lang=lang,
|
||||
review_result=review_result,
|
||||
)
|
||||
tree_export.write_messages_to_file(export_file, messages, use_compression)
|
||||
else:
|
||||
@@ -86,6 +93,7 @@ def export_trees(
|
||||
deleted=deleted,
|
||||
prompts_only=prompts_only,
|
||||
lang=None,
|
||||
review_result=review_result,
|
||||
)
|
||||
for (tree_id, _) in message_tree_ids
|
||||
]
|
||||
@@ -135,6 +143,16 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Export only deleted messages (implies --include-deleted)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-spam",
|
||||
action="store_true",
|
||||
help="Also export messages with negative review result (excluded by default).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spam-only",
|
||||
action="store_true",
|
||||
help="Export only messages with negative review result (implies --include-spam).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user",
|
||||
type=str,
|
||||
@@ -176,6 +194,12 @@ def main():
|
||||
if args.deleted_only:
|
||||
deleted = True
|
||||
|
||||
review_result: Optional[bool] = True
|
||||
if args.include_spam:
|
||||
review_result = None
|
||||
if args.spam_only:
|
||||
review_result = False
|
||||
|
||||
with Session(engine) as db:
|
||||
export_trees(
|
||||
db,
|
||||
@@ -186,6 +210,7 @@ def main():
|
||||
prompts_only=args.prompts_only,
|
||||
state_filter=state_filter,
|
||||
lang=args.lang,
|
||||
review_result=review_result,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ class ExportMessageNode(BaseModel):
|
||||
role: str
|
||||
lang: str | None
|
||||
review_count: int | None
|
||||
review_result: bool | None
|
||||
rank: int | None
|
||||
synthetic: bool | None
|
||||
model_name: str | None
|
||||
@@ -36,6 +37,7 @@ class ExportMessageNode(BaseModel):
|
||||
role=message.role,
|
||||
lang=message.lang,
|
||||
review_count=message.review_count,
|
||||
review_result=message.review_result if message.review_result or message.review_count > 2 else None,
|
||||
synthetic=message.synthetic,
|
||||
model_name=message.model_name,
|
||||
emojis=message.emojis,
|
||||
|
||||
Reference in New Issue
Block a user