add spam-only export option

This commit is contained in:
Andreas Köpf
2023-02-08 22:23:42 +01:00
parent 8bad8c32cd
commit af6885d416
2 changed files with 29 additions and 2 deletions
+27 -2
View File
@@ -8,7 +8,7 @@ from oasst_backend.database import engine
from oasst_backend.models import Message, MessageTreeState
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_backend.utils import tree_export
from sqlmodel import Session
from sqlmodel import Session, not_
def fetch_tree_ids(
@@ -38,6 +38,7 @@ def fetch_tree_messages(
deleted: bool = None,
prompts_only: bool = False,
lang: Optional[str] = None,
review_result: Optional[bool] = None,
) -> List[Message]:
qry = db.query(Message)
@@ -51,6 +52,10 @@ def fetch_tree_messages(
qry = qry.filter(Message.parent_id.is_(None))
if lang:
qry = qry.filter(Message.lang == lang)
if review_result is False:
qry = qry.filter(not_(Message.review_result), Message.review_count > 2)
elif review_result is True:
qry = qry.filter(Message.review_result)
return qry.all()
@@ -64,16 +69,18 @@ def export_trees(
prompts_only: bool = False,
state_filter: Optional[TreeState] = None,
lang: Optional[str] = None,
review_result: Optional[bool] = None,
) -> None:
trees_to_export: List[tree_export.ExportMessageTree] = []
if user_id:
if user_id or review_result is False:
messages = fetch_tree_messages(
db,
user_id=user_id,
deleted=deleted,
prompts_only=prompts_only,
lang=lang,
review_result=review_result,
)
tree_export.write_messages_to_file(export_file, messages, use_compression)
else:
@@ -86,6 +93,7 @@ def export_trees(
deleted=deleted,
prompts_only=prompts_only,
lang=None,
review_result=review_result,
)
for (tree_id, _) in message_tree_ids
]
@@ -135,6 +143,16 @@ def parse_args():
action="store_true",
help="Export only deleted messages (implies --include-deleted)",
)
parser.add_argument(
"--include-spam",
action="store_true",
help="Export only messages with negative review result.",
)
parser.add_argument(
"--spam-only",
action="store_true",
help="Export only messages with negative review result (implies --include-spam).",
)
parser.add_argument(
"--user",
type=str,
@@ -176,6 +194,12 @@ def main():
if args.deleted_only:
deleted = True
review_result: Optional[bool] = True
if args.include_spam:
review_result = None
if args.spam_only:
review_result = False
with Session(engine) as db:
export_trees(
db,
@@ -186,6 +210,7 @@ def main():
prompts_only=args.prompts_only,
state_filter=state_filter,
lang=args.lang,
review_result=review_result,
)
@@ -21,6 +21,7 @@ class ExportMessageNode(BaseModel):
role: str
lang: str | None
review_count: int | None
review_result: bool | None
rank: int | None
synthetic: bool | None
model_name: str | None
@@ -36,6 +37,7 @@ class ExportMessageNode(BaseModel):
role=message.role,
lang=message.lang,
review_count=message.review_count,
review_result=message.review_result if message.review_result or message.review_count > 2 else None,
synthetic=message.synthetic,
model_name=message.model_name,
emojis=message.emojis,