mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add dupe checks to store_text_reply() & store_text_labels() in PromptRepository (#1018)
* add dupe checks to store_text_reply() & store_text_labels() * remove test export file * add user_id to protocol.ConversationMessage * add show_on_leaderboard to protocol.FrontEndUser
This commit is contained in:
@@ -58,7 +58,7 @@ def get_children_by_frontend_id(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
messages = pr.fetch_message_children(message.id, review_result=None)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
|
||||
@@ -201,7 +201,7 @@ def get_children(
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
messages = pr.fetch_message_children(message_id, review_result=None)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,15 @@ def prepare_message(m: Message) -> protocol.Message:
|
||||
id=m.id,
|
||||
frontend_message_id=m.frontend_message_id,
|
||||
parent_id=m.parent_id,
|
||||
user_id=m.user_id,
|
||||
text=m.text,
|
||||
lang=m.lang,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
emojis=m.emojis or {},
|
||||
user_emojis=m.user_emojis or [],
|
||||
review_result=m.review_result,
|
||||
review_count=m.review_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -26,6 +29,7 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
def prepare_conversation_message(message: Message) -> protocol.ConversationMessage:
|
||||
return protocol.ConversationMessage(
|
||||
id=message.id,
|
||||
user_id=message.user_id,
|
||||
frontend_message_id=message.frontend_message_id,
|
||||
text=message.text,
|
||||
lang=message.lang,
|
||||
|
||||
@@ -42,4 +42,5 @@ class User(SQLModel, table=True):
|
||||
deleted=self.deleted,
|
||||
notes=self.notes,
|
||||
created_date=self.created_date,
|
||||
show_on_leaderboard=self.show_on_leaderboard,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import FrontendUserId
|
||||
from oasst_backend.config import settings
|
||||
@@ -62,13 +61,11 @@ class PromptRepository:
|
||||
|
||||
if user_id:
|
||||
self.user = self.user_repository.get_user(id=user_id)
|
||||
self.user_id = self.user.id
|
||||
elif auth_method and username:
|
||||
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
|
||||
self.user_id = self.user.id
|
||||
else:
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.user_id = self.user.id if self.user else None
|
||||
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
@@ -215,6 +212,14 @@ class PromptRepository:
|
||||
OasstErrorCode.TREE_NOT_IN_GROWING_STATE,
|
||||
)
|
||||
|
||||
if check_duplicate and not settings.DEBUG_ALLOW_DUPLICATE_TASKS:
|
||||
siblings = self.fetch_message_children(task.parent_message_id, review_result=None, deleted=False)
|
||||
if any(m.user_id == self.user_id for m in siblings):
|
||||
raise OasstError(
|
||||
"User cannot reply twice to the same message.",
|
||||
OasstErrorCode.TASK_MESSAGE_DUPLICATE_REPLY,
|
||||
)
|
||||
|
||||
parent_message.message_tree_id
|
||||
parent_message.children_count += 1
|
||||
self.db.add(parent_message)
|
||||
@@ -419,6 +424,7 @@ class PromptRepository:
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[TextLabels, Task, Message]:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
valid_labels: Optional[list[str]] = None
|
||||
mandatory_labels: Optional[list[str]] = None
|
||||
@@ -484,6 +490,8 @@ class PromptRepository:
|
||||
message: Message = None
|
||||
if message_id:
|
||||
if not task:
|
||||
# free labeling case
|
||||
|
||||
if text_labels.is_report is True:
|
||||
message = self.handle_message_emoji(
|
||||
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
|
||||
@@ -496,7 +504,21 @@ class PromptRepository:
|
||||
model = existing_text_label
|
||||
|
||||
else:
|
||||
message = self.fetch_message(message_id)
|
||||
# task based labeling case
|
||||
|
||||
message = self.fetch_message(message_id, fail_if_missing=True)
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING and message.user_id == self.user_id:
|
||||
raise OasstError(
|
||||
"Labeling own message is not allowed.", OasstErrorCode.TEXT_LABELS_NO_SELF_LABELING
|
||||
)
|
||||
|
||||
existing_labels = self.fetch_message_text_labels(message_id, self.user_id)
|
||||
if not settings.DEBUG_ALLOW_DUPLICATE_TASKS and any(l.task_id for l in existing_labels):
|
||||
raise OasstError(
|
||||
"Message was already labeled by same user before.",
|
||||
OasstErrorCode.TEXT_LABELS_DUPLICATE_TASK_REPLY,
|
||||
)
|
||||
|
||||
message.review_count += 1
|
||||
self.db.add(message)
|
||||
|
||||
@@ -666,6 +688,12 @@ class PromptRepository:
|
||||
text_label = query.one_or_none()
|
||||
return text_label
|
||||
|
||||
def fetch_message_text_labels(self, message_id: UUID, user_id: Optional[UUID] = None) -> list[TextLabels]:
|
||||
query = self.db.query(TextLabels).filter(TextLabels.message_id == message_id)
|
||||
if user_id is not None:
|
||||
query = query.filter(TextLabels.user_id == user_id)
|
||||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
@@ -712,7 +740,10 @@ class PromptRepository:
|
||||
return self.fetch_message_tree(message.message_tree_id)
|
||||
|
||||
def fetch_message_children(
|
||||
self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True
|
||||
self,
|
||||
message: Message | UUID,
|
||||
review_result: Optional[bool] = True,
|
||||
deleted: Optional[bool] = False,
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Get all direct children of this message
|
||||
@@ -721,26 +752,31 @@ class PromptRepository:
|
||||
message = message.id
|
||||
|
||||
qry = self.db.query(Message).filter(Message.parent_id == message)
|
||||
if reviewed:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if exclude_deleted:
|
||||
qry = qry.filter(Message.deleted == sa.false())
|
||||
if review_result is not None:
|
||||
qry = qry.filter(Message.review_result == review_result)
|
||||
if deleted is not None:
|
||||
qry = qry.filter(Message.deleted == deleted)
|
||||
children = self._add_user_emojis_all(qry)
|
||||
return children
|
||||
|
||||
def fetch_message_siblings(
|
||||
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
|
||||
self,
|
||||
message: Message | UUID,
|
||||
review_result: Optional[bool] = True,
|
||||
deleted: Optional[bool] = False,
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Get siblings of a message (other messages with the same parent_id)
|
||||
"""
|
||||
qry = self.db.query(Message)
|
||||
if isinstance(message, Message):
|
||||
message = message.id
|
||||
qry = qry.filter(Message.parent_id == message.parent_id)
|
||||
else:
|
||||
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
|
||||
qry = qry.filter(Message.parent_id == parent_qry.c.parent_id)
|
||||
|
||||
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
|
||||
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
|
||||
if reviewed is not None:
|
||||
qry = qry.filter(Message.review_result == reviewed)
|
||||
if review_result is not None:
|
||||
qry = qry.filter(Message.review_result == review_result)
|
||||
if deleted is not None:
|
||||
qry = qry.filter(Message.deleted == deleted)
|
||||
siblings = self._add_user_emojis_all(qry)
|
||||
|
||||
@@ -319,7 +319,7 @@ class TreeManager:
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, review_result=True, deleted=False)
|
||||
|
||||
assert len(replies) > 1
|
||||
random.shuffle(replies) # hand out replies in random order
|
||||
@@ -756,7 +756,7 @@ class TreeManager:
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
|
||||
# fetch all siblings and clear ranks
|
||||
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
|
||||
siblings = self.pr.fetch_message_siblings(consensus[0], review_result=None, deleted=None)
|
||||
for m in siblings:
|
||||
m.rank = None
|
||||
self.db.add(m)
|
||||
|
||||
@@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_MESSAGE_TOO_LONG = 1008
|
||||
TASK_MESSAGE_DUPLICATED = 1009
|
||||
TASK_MESSAGE_TEXT_EMPTY = 1010
|
||||
TASK_MESSAGE_DUPLICATE_REPLY = 1011
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
@@ -59,6 +60,8 @@ class OasstErrorCode(IntEnum):
|
||||
TEXT_LABELS_WRONG_MESSAGE_ID = 2050
|
||||
TEXT_LABELS_INVALID_LABEL = 2051
|
||||
TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052
|
||||
TEXT_LABELS_NO_SELF_LABELING = 2053
|
||||
# Must be distinct from TEXT_LABELS_NO_SELF_LABELING (2053): a repeated IntEnum value turns this name into an alias of that member.
TEXT_LABELS_DUPLICATE_TASK_REPLY = 2054
|
||||
|
||||
TASK_NOT_FOUND = 2100
|
||||
TASK_EXPIRED = 2101
|
||||
|
||||
@@ -35,6 +35,7 @@ class FrontEndUser(User):
|
||||
deleted: bool
|
||||
notes: str
|
||||
created_date: Optional[datetime] = None
|
||||
show_on_leaderboard: bool
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
@@ -53,6 +54,7 @@ class ConversationMessage(BaseModel):
|
||||
"""Represents a message in a conversation between the user and the assistant."""
|
||||
|
||||
id: Optional[UUID] = None
|
||||
user_id: Optional[UUID]
|
||||
frontend_message_id: Optional[str] = None
|
||||
text: str
|
||||
lang: Optional[str] # BCP 47
|
||||
@@ -80,8 +82,10 @@ class Conversation(BaseModel):
|
||||
|
||||
|
||||
class Message(ConversationMessage):
|
||||
parent_id: Optional[UUID] = None
|
||||
created_date: Optional[datetime] = None
|
||||
parent_id: Optional[UUID]
|
||||
created_date: Optional[datetime]
|
||||
review_result: Optional[bool]
|
||||
review_count: Optional[int]
|
||||
|
||||
|
||||
class MessagePage(PageResult):
|
||||
|
||||
Reference in New Issue
Block a user