From c2fa476904552ceca5675568f7645cae22de26fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 15:29:54 +0100 Subject: [PATCH] Add user emoji augmentation for message queries (#937) * add disposition to text labeling tasks * add emoji stats to ConversationMessage * add user emoji augmentation for message queries * add auth_method,username to message queries (query emoji status) * add auth_method+username for single message * fix param name typo * only join rows when message.emojis != JSON.NULL * formatting * make sure emojis and user_emojis default to {}, [] * remove init_user(), use fresh empty default collections --- .../oasst_backend/api/v1/frontend_users.py | 2 +- backend/oasst_backend/api/v1/messages.py | 60 ++++++++++++----- backend/oasst_backend/api/v1/users.py | 2 +- backend/oasst_backend/api/v1/utils.py | 5 +- backend/oasst_backend/models/message.py | 17 ++++- backend/oasst_backend/prompt_repository.py | 67 ++++++++++++++++--- backend/oasst_backend/tree_manager.py | 8 +++ oasst-shared/oasst_shared/schemas/protocol.py | 28 +++++--- 8 files changed, 149 insertions(+), 40 deletions(-) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 5ea7b26c..86f78026 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -77,7 +77,7 @@ def query_frontend_user_messages( """ Query frontend user messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index af3ae42d..b3aace40 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -34,7 +34,7 @@ def query_messages( """ Query messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, @@ -93,7 +93,7 @@ def get_messages_cursor( qry_max_count = max_count + 1 if before is None or after is None else max_count - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username, user_id=user_id) items = pr.query_messages_ordered_by_created_date( user_id=user_id, auth_method=auth_method, @@ -137,37 +137,49 @@ def get_messages_cursor( @router.get("/{message_id}", response_model=protocol.Message) def get_message( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a message by its internal ID. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) return utils.prepare_message(message) @router.get("/{message_id}/conversation", response_model=protocol.Conversation) def get_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a conversation from the tree root and up to the message with given internal ID. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.fetch_message_conversation(message_id) return utils.prepare_conversation(messages) @router.get("/{message_id}/tree", response_model=protocol.MessageTree) def get_tree( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False) return utils.prepare_tree(tree, message.message_tree_id) @@ -175,24 +187,32 @@ def get_tree( @router.get("/{message_id}/children", response_model=list[protocol.Message]) def get_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.fetch_message_children(message_id) return utils.prepare_message_list(messages) @router.get("/{message_id}/descendants", response_model=protocol.MessageTree) def get_descendants( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a subtree which starts with this message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -200,12 +220,16 @@ def get_descendants( @router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation) def get_longest_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get the longest conversation from the tree of the message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -213,12 +237,16 @@ def get_longest_conv( @router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree) def get_max_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get message with the most children from the tree of the provided message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index c0055339..d7497610 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -230,7 +230,7 @@ def query_user_messages( """ Query user messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, user_id=user_id) messages = pr.query_messages_ordered_by_created_date( user_id=user_id, api_client_id=api_client_id, diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 8b0f378f..5c9537a3 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -14,7 +14,8 @@ def prepare_message(m: Message) -> protocol.Message: lang=m.lang, is_assistant=(m.role == "assistant"), created_date=m.created_date, - emojis=m.emojis, + emojis=m.emojis or {}, + user_emojis=m.user_emojis or [], ) @@ -30,6 +31,8 @@ def prepare_conversation_message_list(messages: list[Message]) -> list[protocol. text=message.text, lang=message.lang, is_assistant=(message.role == "assistant"), + emojis=message.emojis or {}, + user_emojis=message.user_emojis or [], ) for message in messages ] diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index da0c06c3..5f323d5d 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,12 +1,13 @@ from datetime import datetime from http import HTTPStatus -from typing import Optional +from typing import Any, Optional from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from pydantic import PrivateAttr from sqlalchemy import false from sqlmodel import Field, Index, SQLModel @@ -17,6 +18,13 @@ class Message(SQLModel, table=True): __tablename__ = "message" __table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),) + def __new__(cls, *args: Any, **kwargs: Any): + new_object = super().__new__(cls, *args, **kwargs) + # temporary fix until https://github.com/tiangolo/sqlmodel/issues/149 gets merged + if not hasattr(new_object, "_user_emojis"): + new_object._init_private_attributes() + return new_object + id: Optional[UUID] = Field( sa_column=sa.Column( pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") @@ -49,7 +57,8 @@ class Message(SQLModel, table=True): rank: Optional[int] = Field(nullable=True) - emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False) + emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False) + _user_emojis: Optional[list[str]] = PrivateAttr(default=None) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): @@ -59,3 +68,7 @@ class Message(SQLModel, table=True): def text(self) -> str: self.ensure_is_message() return self.payload.payload.text + + @property + def user_emojis(self) -> str: + return self._user_emojis diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7dddb5cf..b31b53d7 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -30,8 +30,9 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats from oasst_shared.utils import unaware_to_utc +from sqlalchemy.orm import Query from sqlalchemy.orm.attributes import flag_modified -from sqlmodel import Session, and_, func, not_, or_, text, update +from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -41,14 +42,25 @@ class PromptRepository: db: Session, api_client: ApiClient, client_user: Optional[protocol_schema.User] = None, + *, user_repository: Optional[UserRepository] = None, task_repository: Optional[TaskRepository] = None, + user_id: Optional[UUID] = None, + auth_method: Optional[str] = None, + username: Optional[str] = None, ): self.db = db self.api_client = api_client self.user_repository = user_repository or UserRepository(db, api_client) - self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) - self.user_id = self.user.id if self.user else None + if user_id: + self.user = self.user_repository.get_user(id=user_id) + self.user_id = self.user.id + elif auth_method and username: + self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username) + self.user_id = self.user.id + else: + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) + self.user_id = self.user.id if self.user else None logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})") self.task_repository = task_repository or TaskRepository( db, api_client, client_user, user_repository=self.user_repository @@ -529,7 +541,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if not include_deleted: qry = qry.filter(not_(Message.deleted)) - return qry.all() + return self._add_user_emojis_all(qry) def fetch_user_message_trees( self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False @@ -539,7 +551,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if not include_deleted: qry = qry.filter(not_(Message.deleted)) - return qry.all() + return self._add_user_emojis_all(qry) def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]: qry = self.db.query(MessageTreeState).filter( @@ -582,6 +594,10 @@ class PromptRepository: return conversation, replies def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]: + qry = self.db.query(Message).filter(Message.id == message_id) + messages = self._add_user_emojis_all(qry) + message = messages[0] if messages else None + message = self.db.query(Message).filter(Message.id == message_id).one_or_none() if fail_if_missing and not message: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) @@ -656,7 +672,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if exclude_deleted: qry = qry.filter(Message.deleted == sa.false()) - children = qry.all() + children = self._add_user_emojis_all(qry) return children def fetch_message_siblings( @@ -674,7 +690,7 @@ class PromptRepository: qry = qry.filter(Message.review_result == reviewed) if deleted is not None: qry = qry.filter(Message.deleted == deleted) - siblings = qry.all() + siblings = self._add_user_emojis_all(qry) return siblings @staticmethod @@ -705,7 +721,7 @@ class PromptRepository: if max_depth is not None: desc = desc.filter(Message.depth <= max_depth) - desc = desc.all() + desc = self._add_user_emojis_all(desc) return self.trace_descendants(message, desc) @@ -719,6 +735,33 @@ class PromptRepository: max_message = max(tree, key=lambda m: m.children_count) return max_message, [m for m in tree if m.parent_id == max_message.id] + def _add_user_emojis_all(self, qry: Query) -> list[Message]: + if self.user_id is None: + return qry.all() + + sq = qry.subquery("m") + qry = ( + self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis")) + .select_entity_from(sq) + .outerjoin( + MessageEmoji, + and_( + sq.c.id == MessageEmoji.message_id, + MessageEmoji.user_id == self.user_id, + sq.c.emojis != JSON.NULL, + ), + ) + .group_by(sq) + ) + messages: list[Message] = [] + for x in qry: + m: Message = x.Message + user_emojis = x["user_emojis"] + if user_emojis: + m._user_emojis = user_emojis.split(",") + messages.append(m) + return messages + def query_messages_ordered_by_created_date( self, user_id: Optional[UUID] = None, @@ -801,7 +844,7 @@ class PromptRepository: if lang is not None: qry = qry.filter(Message.lang == lang) - return qry.all() + return self._add_user_emojis_all(qry) def update_children_counts(self, message_tree_id: UUID): sql_update_children_count = """ @@ -902,9 +945,15 @@ WHERE message.id = cc.id; else: count = emoji_counts.get(emoji.value) or 0 emoji_counts[emoji.value] = count + 1 + if message._user_emojis is None: + message._user_emojis = [] + if emoji.value not in message._user_emojis: + message._user_emojis.append(emoji.value) elif op == protocol_schema.EmojiOp.remove: # remove emoji record and & decrement count message = self.fetch_message(message_id) + if message._user_emojis and emoji.value in message._user_emojis: + message._user_emojis.remove(emoji.value) self.db.delete(existing_emoji) emoji_counts = message.emojis count = emoji_counts.get(emoji.value) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 64e1883e..929a9297 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -354,6 +354,7 @@ class TreeManager: self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 label_mode = protocol_schema.LabelTaskMode.full + label_disposition = protocol_schema.LabelTaskDisposition.quality valid_labels = self._all_text_labels if message.role == "assistant": @@ -363,6 +364,8 @@ class TreeManager: ): valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam + logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") task = protocol_schema.LabelAssistantReplyTask( message_id=message.id, @@ -371,6 +374,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), mode=label_mode, + disposition=label_disposition, ) else: if ( @@ -387,6 +391,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), mode=label_mode, + disposition=label_disposition, ) parent_message_id = message.id @@ -424,11 +429,13 @@ class TreeManager: message = random.choice(prompts_need_review) label_mode = protocol_schema.LabelTaskMode.full + label_disposition = protocol_schema.LabelTaskDisposition.quality valid_labels = self._all_text_labels if random.random() > self.cfg.p_full_labeling_review_prompt: valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).") task = protocol_schema.LabelInitialPromptTask( @@ -437,6 +444,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), mode=label_mode, + disposition=label_disposition, ) parent_message_id = message.id diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index de431d75..31caa340 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -57,6 +57,8 @@ class ConversationMessage(BaseModel): text: str lang: Optional[str] # BCP 47 is_assistant: bool + emojis: Optional[dict[str, int]] = None + user_emojis: Optional[list[str]] = None class Conversation(BaseModel): @@ -80,7 +82,6 @@ class Conversation(BaseModel): class Message(ConversationMessage): parent_id: Optional[UUID] = None created_date: Optional[datetime] = None - emojis: Optional[dict] = None class MessagePage(PageResult): @@ -223,27 +224,34 @@ class LabelTaskMode(str, enum.Enum): full = "full" -class LabelInitialPromptTask(Task): - """A task to label an initial prompt.""" +class LabelTaskDisposition(str, enum.Enum): + """Reason why the task was issued.""" - type: Literal["label_initial_prompt"] = "label_initial_prompt" + quality = "quality" + spam = "spam" + + +class AbstractLabelTask(Task): message_id: UUID - prompt: str valid_labels: list[str] mandatory_labels: Optional[list[str]] mode: Optional[LabelTaskMode] + disposition: Optional[LabelTaskDisposition] -class LabelConversationReplyTask(Task): +class LabelInitialPromptTask(AbstractLabelTask): + """A task to label an initial prompt.""" + + type: Literal["label_initial_prompt"] = "label_initial_prompt" + prompt: str + + +class LabelConversationReplyTask(AbstractLabelTask): """A task to label a reply to a conversation.""" type: Literal["label_conversation_reply"] = "label_conversation_reply" conversation: Conversation # the conversation so far - message_id: UUID reply: str - valid_labels: list[str] - mandatory_labels: Optional[list[str]] - mode: Optional[LabelTaskMode] class LabelPrompterReplyTask(LabelConversationReplyTask):