Add dupe checks to store_text_reply() & store_text_labels() in PromptRepository (#1018)

* add dupe checks to store_text_reply() & store_text_labels()

* remove test export file

* add user_id to protocol.ConversationMessage

* add show_on_leaderboard to protocol.FrontEndUser
This commit is contained in:
Andreas Köpf
2023-01-30 20:53:59 +01:00
committed by GitHub
parent b8a62e5f4f
commit a5bc9bf492
8 changed files with 70 additions and 22 deletions
@@ -58,7 +58,7 @@ def get_children_by_frontend_id(
"""
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
messages = pr.fetch_message_children(message.id, review_result=None)
return utils.prepare_message_list(messages)
+1 -1
View File
@@ -201,7 +201,7 @@ def get_children(
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
messages = pr.fetch_message_children(message_id)
messages = pr.fetch_message_children(message_id, review_result=None)
return utils.prepare_message_list(messages)
+4
View File
@@ -10,12 +10,15 @@ def prepare_message(m: Message) -> protocol.Message:
id=m.id,
frontend_message_id=m.frontend_message_id,
parent_id=m.parent_id,
user_id=m.user_id,
text=m.text,
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
emojis=m.emojis or {},
user_emojis=m.user_emojis or [],
review_result=m.review_result,
review_count=m.review_count,
)
@@ -26,6 +29,7 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
def prepare_conversation_message(message: Message) -> protocol.ConversationMessage:
return protocol.ConversationMessage(
id=message.id,
user_id=message.user_id,
frontend_message_id=message.frontend_message_id,
text=message.text,
lang=message.lang,
+1
View File
@@ -42,4 +42,5 @@ class User(SQLModel, table=True):
deleted=self.deleted,
notes=self.notes,
created_date=self.created_date,
show_on_leaderboard=self.show_on_leaderboard,
)
+52 -16
View File
@@ -7,7 +7,6 @@ from typing import Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
import sqlalchemy as sa
from loguru import logger
from oasst_backend.api.deps import FrontendUserId
from oasst_backend.config import settings
@@ -62,13 +61,11 @@ class PromptRepository:
if user_id:
self.user = self.user_repository.get_user(id=user_id)
self.user_id = self.user.id
elif auth_method and username:
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
self.user_id = self.user.id
else:
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
self.user_id = self.user.id if self.user else None
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
@@ -215,6 +212,14 @@ class PromptRepository:
OasstErrorCode.TREE_NOT_IN_GROWING_STATE,
)
if check_duplicate and not settings.DEBUG_ALLOW_DUPLICATE_TASKS:
siblings = self.fetch_message_children(task.parent_message_id, review_result=None, deleted=False)
if any(m.user_id == self.user_id for m in siblings):
raise OasstError(
"User cannot reply twice to the same message.",
OasstErrorCode.TASK_MESSAGE_DUPLICATE_REPLY,
)
parent_message.message_tree_id
parent_message.children_count += 1
self.db.add(parent_message)
@@ -419,6 +424,7 @@ class PromptRepository:
@managed_tx_method(CommitMode.FLUSH)
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[TextLabels, Task, Message]:
self.ensure_user_is_enabled()
valid_labels: Optional[list[str]] = None
mandatory_labels: Optional[list[str]] = None
@@ -484,6 +490,8 @@ class PromptRepository:
message: Message = None
if message_id:
if not task:
# free labeling case
if text_labels.is_report is True:
message = self.handle_message_emoji(
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
@@ -496,7 +504,21 @@ class PromptRepository:
model = existing_text_label
else:
message = self.fetch_message(message_id)
# task based labeling case
message = self.fetch_message(message_id, fail_if_missing=True)
if not settings.DEBUG_ALLOW_SELF_LABELING and message.user_id == self.user_id:
raise OasstError(
"Labeling own message is not allowed.", OasstErrorCode.TEXT_LABELS_NO_SELF_LABELING
)
existing_labels = self.fetch_message_text_labels(message_id, self.user_id)
if not settings.DEBUG_ALLOW_DUPLICATE_TASKS and any(l.task_id for l in existing_labels):
raise OasstError(
"Message was already labeled by same user before.",
OasstErrorCode.TEXT_LABELS_DUPLICATE_TASK_REPLY,
)
message.review_count += 1
self.db.add(message)
@@ -666,6 +688,12 @@ class PromptRepository:
text_label = query.one_or_none()
return text_label
def fetch_message_text_labels(self, message_id: UUID, user_id: Optional[UUID] = None) -> list[TextLabels]:
query = self.db.query(TextLabels).filter(TextLabels.message_id == message_id)
if user_id is not None:
query = query.filter(TextLabels.user_id == user_id)
return query.all()
@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
@@ -712,7 +740,10 @@ class PromptRepository:
return self.fetch_message_tree(message.message_tree_id)
def fetch_message_children(
self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True
self,
message: Message | UUID,
review_result: Optional[bool] = True,
deleted: Optional[bool] = False,
) -> list[Message]:
"""
Get all direct children of this message
@@ -721,26 +752,31 @@ class PromptRepository:
message = message.id
qry = self.db.query(Message).filter(Message.parent_id == message)
if reviewed:
qry = qry.filter(Message.review_result)
if exclude_deleted:
qry = qry.filter(Message.deleted == sa.false())
if review_result is not None:
qry = qry.filter(Message.review_result == review_result)
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
children = self._add_user_emojis_all(qry)
return children
def fetch_message_siblings(
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
self,
message: Message | UUID,
review_result: Optional[bool] = True,
deleted: Optional[bool] = False,
) -> list[Message]:
"""
Get siblings of a message (other messages with the same parent_id)
"""
qry = self.db.query(Message)
if isinstance(message, Message):
message = message.id
qry = qry.filter(Message.parent_id == message.parent_id)
else:
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
qry = qry.filter(Message.parent_id == parent_qry.c.parent_id)
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
if reviewed is not None:
qry = qry.filter(Message.review_result == reviewed)
if review_result is not None:
qry = qry.filter(Message.review_result == review_result)
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
siblings = self._add_user_emojis_all(qry)
+2 -2
View File
@@ -319,7 +319,7 @@ class TreeManager:
ranking_parent = messages[-1]
assert not ranking_parent.deleted and ranking_parent.review_result
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
replies = self.pr.fetch_message_children(ranking_parent_id, review_result=True, deleted=False)
assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
@@ -756,7 +756,7 @@ class TreeManager:
logger.debug(f"CONSENSUS: {consensus}\n\n")
# fetch all siblings and clear ranks
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
siblings = self.pr.fetch_message_siblings(consensus[0], review_result=None, deleted=None)
for m in siblings:
m.rank = None
self.db.add(m)
@@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum):
TASK_MESSAGE_TOO_LONG = 1008
TASK_MESSAGE_DUPLICATED = 1009
TASK_MESSAGE_TEXT_EMPTY = 1010
TASK_MESSAGE_DUPLICATE_REPLY = 1011
# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000
@@ -59,6 +60,8 @@ class OasstErrorCode(IntEnum):
# Text-label error codes (2050+). Every member needs its own value: inside an
# IntEnum a repeated value turns the later name into an alias of the earlier
# one, so the two errors could not be told apart by clients or in logs.
TEXT_LABELS_WRONG_MESSAGE_ID = 2050
TEXT_LABELS_INVALID_LABEL = 2051
TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052
TEXT_LABELS_NO_SELF_LABELING = 2053
TEXT_LABELS_DUPLICATE_TASK_REPLY = 2054  # was 2053, colliding with NO_SELF_LABELING
TASK_NOT_FOUND = 2100
TASK_EXPIRED = 2101
@@ -35,6 +35,7 @@ class FrontEndUser(User):
deleted: bool
notes: str
created_date: Optional[datetime] = None
show_on_leaderboard: bool
class PageResult(BaseModel):
@@ -53,6 +54,7 @@ class ConversationMessage(BaseModel):
"""Represents a message in a conversation between the user and the assistant."""
id: Optional[UUID] = None
user_id: Optional[UUID]
frontend_message_id: Optional[str] = None
text: str
lang: Optional[str] # BCP 47
@@ -80,8 +82,10 @@ class Conversation(BaseModel):
class Message(ConversationMessage):
parent_id: Optional[UUID] = None
created_date: Optional[datetime] = None
parent_id: Optional[UUID]
created_date: Optional[datetime]
review_result: Optional[bool]
review_count: Optional[int]
class MessagePage(PageResult):