diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index dffb2824..9ee555ce 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -58,7 +58,7 @@ def get_children_by_frontend_id( """ pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) - messages = pr.fetch_message_children(message.id) + messages = pr.fetch_message_children(message.id, review_result=None) return utils.prepare_message_list(messages) diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 1fbaf53c..29468bf1 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -201,7 +201,7 @@ def get_children( Get all messages belonging to the same message tree. """ pr = PromptRepository(db, api_client, frontend_user=frontend_user) - messages = pr.fetch_message_children(message_id) + messages = pr.fetch_message_children(message_id, review_result=None) return utils.prepare_message_list(messages) diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index b9e982ab..245114ef 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -10,12 +10,15 @@ def prepare_message(m: Message) -> protocol.Message: id=m.id, frontend_message_id=m.frontend_message_id, parent_id=m.parent_id, + user_id=m.user_id, text=m.text, lang=m.lang, is_assistant=(m.role == "assistant"), created_date=m.created_date, emojis=m.emojis or {}, user_emojis=m.user_emojis or [], + review_result=m.review_result, + review_count=m.review_count, ) @@ -26,6 +29,7 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]: def prepare_conversation_message(message: Message) -> protocol.ConversationMessage: return protocol.ConversationMessage( id=message.id, + user_id=message.user_id, 
frontend_message_id=message.frontend_message_id, text=message.text, lang=message.lang, diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 6b6dced4..d8cd1a39 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -42,4 +42,5 @@ class User(SQLModel, table=True): deleted=self.deleted, notes=self.notes, created_date=self.created_date, + show_on_leaderboard=self.show_on_leaderboard, ) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 50808285..e889e73b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -7,7 +7,6 @@ from typing import Optional from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload -import sqlalchemy as sa from loguru import logger from oasst_backend.api.deps import FrontendUserId from oasst_backend.config import settings @@ -62,13 +61,11 @@ class PromptRepository: if user_id: self.user = self.user_repository.get_user(id=user_id) - self.user_id = self.user.id elif auth_method and username: self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username) - self.user_id = self.user.id else: self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) - self.user_id = self.user.id if self.user else None + self.user_id = self.user.id if self.user else None logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})") self.task_repository = task_repository or TaskRepository( db, api_client, client_user, user_repository=self.user_repository @@ -215,6 +212,14 @@ class PromptRepository: OasstErrorCode.TREE_NOT_IN_GROWING_STATE, ) + if check_duplicate and not settings.DEBUG_ALLOW_DUPLICATE_TASKS: + siblings = self.fetch_message_children(task.parent_message_id, review_result=None, deleted=False) + if any(m.user_id == self.user_id for m in siblings): + raise 
OasstError( + "User cannot reply twice to the same message.", + OasstErrorCode.TASK_MESSAGE_DUPLICATE_REPLY, + ) + parent_message.message_tree_id parent_message.children_count += 1 self.db.add(parent_message) @@ -419,6 +424,7 @@ class PromptRepository: @managed_tx_method(CommitMode.FLUSH) def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[TextLabels, Task, Message]: + self.ensure_user_is_enabled() valid_labels: Optional[list[str]] = None mandatory_labels: Optional[list[str]] = None @@ -484,6 +490,8 @@ class PromptRepository: message: Message = None if message_id: if not task: + # free labeling case + if text_labels.is_report is True: message = self.handle_message_emoji( message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag @@ -496,7 +504,21 @@ class PromptRepository: model = existing_text_label else: - message = self.fetch_message(message_id) + # task based labeling case + + message = self.fetch_message(message_id, fail_if_missing=True) + if not settings.DEBUG_ALLOW_SELF_LABELING and message.user_id == self.user_id: + raise OasstError( + "Labeling own message is not allowed.", OasstErrorCode.TEXT_LABELS_NO_SELF_LABELING + ) + + existing_labels = self.fetch_message_text_labels(message_id, self.user_id) + if not settings.DEBUG_ALLOW_DUPLICATE_TASKS and any(l.task_id for l in existing_labels): + raise OasstError( + "Message was already labeled by same user before.", + OasstErrorCode.TEXT_LABELS_DUPLICATE_TASK_REPLY, + ) + message.review_count += 1 self.db.add(message) @@ -666,6 +688,12 @@ class PromptRepository: text_label = query.one_or_none() return text_label + def fetch_message_text_labels(self, message_id: UUID, user_id: Optional[UUID] = None) -> list[TextLabels]: + query = self.db.query(TextLabels).filter(TextLabels.message_id == message_id) + if user_id is not None: + query = query.filter(TextLabels.user_id == user_id) + return query.all() + @staticmethod def trace_conversation(messages: list[Message] | 
dict[UUID, Message], last_message: Message) -> list[Message]: """ @@ -712,7 +740,10 @@ class PromptRepository: return self.fetch_message_tree(message.message_tree_id) def fetch_message_children( - self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True + self, + message: Message | UUID, + review_result: Optional[bool] = True, + deleted: Optional[bool] = False, ) -> list[Message]: """ Get all direct children of this message @@ -721,26 +752,31 @@ class PromptRepository: message = message.id qry = self.db.query(Message).filter(Message.parent_id == message) - if reviewed: - qry = qry.filter(Message.review_result) - if exclude_deleted: - qry = qry.filter(Message.deleted == sa.false()) + if review_result is not None: + qry = qry.filter(Message.review_result == review_result) + if deleted is not None: + qry = qry.filter(Message.deleted == deleted) children = self._add_user_emojis_all(qry) return children def fetch_message_siblings( - self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False + self, + message: Message | UUID, + review_result: Optional[bool] = True, + deleted: Optional[bool] = False, ) -> list[Message]: """ Get siblings of a message (other messages with the same parent_id) """ + qry = self.db.query(Message) if isinstance(message, Message): - message = message.id + qry = qry.filter(Message.parent_id == message.parent_id) + else: + parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery() + qry = qry.filter(Message.parent_id == parent_qry.c.parent_id) - parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery() - qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id) - if reviewed is not None: - qry = qry.filter(Message.review_result == reviewed) + if review_result is not None: + qry = qry.filter(Message.review_result == review_result) if deleted is not None: qry = qry.filter(Message.deleted == deleted) siblings = 
self._add_user_emojis_all(qry) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index e917a57c..4c7a7df6 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -319,7 +319,7 @@ class TreeManager: ranking_parent = messages[-1] assert not ranking_parent.deleted and ranking_parent.review_result conversation = prepare_conversation(messages) - replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) + replies = self.pr.fetch_message_children(ranking_parent_id, review_result=True, deleted=False) assert len(replies) > 1 random.shuffle(replies) # hand out replies in random order @@ -756,7 +756,7 @@ class TreeManager: logger.debug(f"CONSENSUS: {consensus}\n\n") # fetch all siblings and clear ranks - siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None) + siblings = self.pr.fetch_message_siblings(consensus[0], review_result=None, deleted=None) for m in siblings: m.rank = None self.db.add(m) diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index e7f5dcd5..2c3650a6 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum): TASK_MESSAGE_TOO_LONG = 1008 TASK_MESSAGE_DUPLICATED = 1009 TASK_MESSAGE_TEXT_EMPTY = 1010 + TASK_MESSAGE_DUPLICATE_REPLY = 1011 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000 @@ -59,6 +60,8 @@ class OasstErrorCode(IntEnum): TEXT_LABELS_WRONG_MESSAGE_ID = 2050 TEXT_LABELS_INVALID_LABEL = 2051 TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052 + TEXT_LABELS_NO_SELF_LABELING = 2053 + TEXT_LABELS_DUPLICATE_TASK_REPLY = 2054 TASK_NOT_FOUND = 2100 TASK_EXPIRED = 2101 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 93d533a6..a89252fb 100644
--- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -35,6 +35,7 @@ class FrontEndUser(User): deleted: bool notes: str created_date: Optional[datetime] = None + show_on_leaderboard: bool class PageResult(BaseModel): @@ -53,6 +54,7 @@ class ConversationMessage(BaseModel): """Represents a message in a conversation between the user and the assistant.""" id: Optional[UUID] = None + user_id: Optional[UUID] frontend_message_id: Optional[str] = None text: str lang: Optional[str] # BCP 47 @@ -80,8 +82,10 @@ class Conversation(BaseModel): class Message(ConversationMessage): - parent_id: Optional[UUID] = None - created_date: Optional[datetime] = None + parent_id: Optional[UUID] + created_date: Optional[datetime] + review_result: Optional[bool] + review_count: Optional[int] class MessagePage(PageResult):