Add user emoji augmentation for message queries (#937)

* add disposition to text labeling tasks

* add emoji stats to ConversationMessage

* add user emoji augmentation for message queries

* add auth_method and username to message queries (to query the user's emoji status)

* add auth_method+username for single message

* fix param name typo

* only join rows when message.emojis != JSON.NULL

* formatting

* make sure emojis and user_emojis default to {} and [] respectively

* remove init_user(), use fresh empty default collections
This commit is contained in:
Andreas Köpf
2023-01-26 15:29:54 +01:00
committed by GitHub
parent 5d4f74f9d6
commit c2fa476904
8 changed files with 149 additions and 40 deletions
@@ -77,7 +77,7 @@ def query_frontend_user_messages(
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
messages = pr.query_messages_ordered_by_created_date(
auth_method=auth_method,
username=username,
+44 -16
View File
@@ -34,7 +34,7 @@ def query_messages(
"""
Query messages.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
messages = pr.query_messages_ordered_by_created_date(
auth_method=auth_method,
username=username,
@@ -93,7 +93,7 @@ def get_messages_cursor(
qry_max_count = max_count + 1 if before is None or after is None else max_count
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, auth_method=auth_method, username=username, user_id=user_id)
items = pr.query_messages_ordered_by_created_date(
user_id=user_id,
auth_method=auth_method,
@@ -137,37 +137,49 @@ def get_messages_cursor(
@router.get("/{message_id}", response_model=protocol.Message)
def get_message(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get a message by its internal ID.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the message is queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    message = pr.fetch_message(message_id)
    return utils.prepare_message(message)
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get a conversation from the tree root and up to the message with given internal ID.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the conversation is queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    messages = pr.fetch_message_conversation(message_id)
    return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get all messages belonging to the same message tree.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the tree is queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    # The requested message is only needed to find out which tree it belongs to.
    message = pr.fetch_message(message_id)
    tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
    return utils.prepare_tree(tree, message.message_tree_id)
@@ -175,24 +187,32 @@ def get_tree(
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get the children of the message with the given internal ID.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the children are queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    messages = pr.fetch_message_children(message_id)
    return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get a subtree which starts with this message.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the subtree is queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    message = pr.fetch_message(message_id)
    descendants = pr.fetch_message_descendants(message)
    # The requested message itself is passed as the root id of the returned tree.
    return utils.prepare_tree(descendants, message.id)
@@ -200,12 +220,16 @@ def get_descendants(
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get the longest conversation from the tree of the message.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the conversation is queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    # The requested message is only needed to find out which tree it belongs to.
    message = pr.fetch_message(message_id)
    conv = pr.fetch_longest_conversation(message.message_tree_id)
    return utils.prepare_conversation(conv)
@@ -213,12 +237,16 @@ def get_longest_conv(
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children(
    message_id: UUID,
    auth_method: Optional[str] = None,
    username: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Get message with the most children from the tree of the provided message.

    `auth_method` and `username` are optional query parameters. They are
    forwarded to the PromptRepository so that it can identify the frontend
    user on whose behalf the tree is queried.
    """
    # Build the repository exactly once, with the user-identifying arguments.
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    # The requested message is only needed to find out which tree it belongs to.
    message = pr.fetch_message(message_id)
    # From here on `message` is the tree's message with the most children.
    message, children = pr.fetch_message_with_max_children(message.message_tree_id)
    return utils.prepare_tree([message, *children], message.id)
+1 -1
View File
@@ -230,7 +230,7 @@ def query_user_messages(
"""
Query user messages.
"""
pr = PromptRepository(db, api_client)
pr = PromptRepository(db, api_client, user_id=user_id)
messages = pr.query_messages_ordered_by_created_date(
user_id=user_id,
api_client_id=api_client_id,
+4 -1
View File
@@ -14,7 +14,8 @@ def prepare_message(m: Message) -> protocol.Message:
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
emojis=m.emojis,
emojis=m.emojis or {},
user_emojis=m.user_emojis or [],
)
@@ -30,6 +31,8 @@ def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.
text=message.text,
lang=message.lang,
is_assistant=(message.role == "assistant"),
emojis=message.emojis or {},
user_emojis=message.user_emojis or [],
)
for message in messages
]
+15 -2
View File
@@ -1,12 +1,13 @@
from datetime import datetime
from http import HTTPStatus
from typing import Optional
from typing import Any, Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from pydantic import PrivateAttr
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel
@@ -17,6 +18,13 @@ class Message(SQLModel, table=True):
__tablename__ = "message"
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
def __new__(cls, *args: Any, **kwargs: Any):
    """Create the instance and make sure its private attributes exist.

    Workaround kept until https://github.com/tiangolo/sqlmodel/issues/149 is
    resolved: if the freshly created object has no `_user_emojis` attribute,
    the private attributes are initialised explicitly.
    """
    instance = super().__new__(cls, *args, **kwargs)
    if hasattr(instance, "_user_emojis"):
        # Private attributes are already in place, nothing to repair.
        return instance
    instance._init_private_attributes()
    return instance
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
@@ -49,7 +57,8 @@ class Message(SQLModel, table=True):
rank: Optional[int] = Field(nullable=True)
emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
@@ -59,3 +68,7 @@ class Message(SQLModel, table=True):
def text(self) -> str:
self.ensure_is_message()
return self.payload.payload.text
@property
def user_emojis(self) -> Optional[list[str]]:
    """Return the emojis stored in the private `_user_emojis` attribute.

    The value is returned unchanged: a list of emoji names, or `None` when no
    per-user emoji information has been attached to this message object.
    """
    return self._user_emojis
+58 -9
View File
@@ -30,8 +30,9 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import unaware_to_utc
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, and_, func, not_, or_, text, update
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -41,14 +42,25 @@ class PromptRepository:
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User] = None,
*,
user_repository: Optional[UserRepository] = None,
task_repository: Optional[TaskRepository] = None,
user_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
username: Optional[str] = None,
):
self.db = db
self.api_client = api_client
self.user_repository = user_repository or UserRepository(db, api_client)
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
if user_id:
self.user = self.user_repository.get_user(id=user_id)
self.user_id = self.user.id
elif auth_method and username:
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
self.user_id = self.user.id
else:
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
@@ -529,7 +541,7 @@ class PromptRepository:
qry = qry.filter(Message.review_result)
if not include_deleted:
qry = qry.filter(not_(Message.deleted))
return qry.all()
return self._add_user_emojis_all(qry)
def fetch_user_message_trees(
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
@@ -539,7 +551,7 @@ class PromptRepository:
qry = qry.filter(Message.review_result)
if not include_deleted:
qry = qry.filter(not_(Message.deleted))
return qry.all()
return self._add_user_emojis_all(qry)
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
qry = self.db.query(MessageTreeState).filter(
@@ -582,6 +594,10 @@ class PromptRepository:
return conversation, replies
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
    """Look up a single message by its id, including the current user's emojis.

    The lookup goes through `_add_user_emojis_all`, so no second plain query
    is issued afterwards.

    Raises:
        OasstError: with MESSAGE_NOT_FOUND / HTTP 404 when no message matches
            and `fail_if_missing` is true.
    """
    qry = self.db.query(Message).filter(Message.id == message_id)
    messages = self._add_user_emojis_all(qry)
    # The id filter yields at most one row; an empty result means "not found".
    message = messages[0] if messages else None
    if fail_if_missing and not message:
        raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
@@ -656,7 +672,7 @@ class PromptRepository:
qry = qry.filter(Message.review_result)
if exclude_deleted:
qry = qry.filter(Message.deleted == sa.false())
children = qry.all()
children = self._add_user_emojis_all(qry)
return children
def fetch_message_siblings(
@@ -674,7 +690,7 @@ class PromptRepository:
qry = qry.filter(Message.review_result == reviewed)
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
siblings = qry.all()
siblings = self._add_user_emojis_all(qry)
return siblings
@staticmethod
@@ -705,7 +721,7 @@ class PromptRepository:
if max_depth is not None:
desc = desc.filter(Message.depth <= max_depth)
desc = desc.all()
desc = self._add_user_emojis_all(desc)
return self.trace_descendants(message, desc)
@@ -719,6 +735,33 @@ class PromptRepository:
max_message = max(tree, key=lambda m: m.children_count)
return max_message, [m for m in tree if m.parent_id == max_message.id]
def _add_user_emojis_all(self, qry: Query) -> list[Message]:
    """Run `qry` and attach the current user's emojis to each resulting message.

    Without a `self.user_id` the query is executed as-is. Otherwise the query
    is wrapped in a subquery and outer-joined with the `MessageEmoji` rows of
    that user; the emoji values are aggregated into one comma-separated string
    per message and stored, split into a list, in `Message._user_emojis`.
    """
    if self.user_id is None:
        return qry.all()

    msg_sq = qry.subquery("m")

    # Only this user's emoji rows are joined, and only for messages whose
    # `emojis` column is not JSON null.
    join_condition = and_(
        msg_sq.c.id == MessageEmoji.message_id,
        MessageEmoji.user_id == self.user_id,
        msg_sq.c.emojis != JSON.NULL,
    )
    aggregated_emojis = func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis")
    augmented_qry = (
        self.db.query(Message, aggregated_emojis)
        .select_entity_from(msg_sq)
        .outerjoin(MessageEmoji, join_condition)
        .group_by(msg_sq)
    )

    result: list[Message] = []
    for row in augmented_qry:
        msg: Message = row.Message
        joined = row["user_emojis"]
        # Messages without matching emoji rows keep their default `_user_emojis`.
        if joined:
            msg._user_emojis = joined.split(",")
        result.append(msg)
    return result
def query_messages_ordered_by_created_date(
self,
user_id: Optional[UUID] = None,
@@ -801,7 +844,7 @@ class PromptRepository:
if lang is not None:
qry = qry.filter(Message.lang == lang)
return qry.all()
return self._add_user_emojis_all(qry)
def update_children_counts(self, message_tree_id: UUID):
sql_update_children_count = """
@@ -902,9 +945,15 @@ WHERE message.id = cc.id;
else:
count = emoji_counts.get(emoji.value) or 0
emoji_counts[emoji.value] = count + 1
if message._user_emojis is None:
message._user_emojis = []
if emoji.value not in message._user_emojis:
message._user_emojis.append(emoji.value)
elif op == protocol_schema.EmojiOp.remove:
# remove emoji record and & decrement count
message = self.fetch_message(message_id)
if message._user_emojis and emoji.value in message._user_emojis:
message._user_emojis.remove(emoji.value)
self.db.delete(existing_emoji)
emoji_counts = message.emojis
count = emoji_counts.get(emoji.value)
+8
View File
@@ -354,6 +354,7 @@ class TreeManager:
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self._all_text_labels
if message.role == "assistant":
@@ -363,6 +364,8 @@ class TreeManager:
):
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
task = protocol_schema.LabelAssistantReplyTask(
message_id=message.id,
@@ -371,6 +374,7 @@ class TreeManager:
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
disposition=label_disposition,
)
else:
if (
@@ -387,6 +391,7 @@ class TreeManager:
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
disposition=label_disposition,
)
parent_message_id = message.id
@@ -424,11 +429,13 @@ class TreeManager:
message = random.choice(prompts_need_review)
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
valid_labels = self._all_text_labels
if random.random() > self.cfg.p_full_labeling_review_prompt:
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt))
label_mode = protocol_schema.LabelTaskMode.simple
label_disposition = protocol_schema.LabelTaskDisposition.spam
logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).")
task = protocol_schema.LabelInitialPromptTask(
@@ -437,6 +444,7 @@ class TreeManager:
valid_labels=valid_labels,
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
disposition=label_disposition,
)
parent_message_id = message.id
+18 -10
View File
@@ -57,6 +57,8 @@ class ConversationMessage(BaseModel):
text: str
lang: Optional[str] # BCP 47
is_assistant: bool
emojis: Optional[dict[str, int]] = None
user_emojis: Optional[list[str]] = None
class Conversation(BaseModel):
@@ -80,7 +82,6 @@ class Conversation(BaseModel):
class Message(ConversationMessage):
parent_id: Optional[UUID] = None
created_date: Optional[datetime] = None
emojis: Optional[dict] = None
class MessagePage(PageResult):
@@ -223,27 +224,34 @@ class LabelTaskMode(str, enum.Enum):
full = "full"
class LabelTaskDisposition(str, enum.Enum):
    """Reason why the task was issued."""

    # Members are plain strings (the class also derives from `str`).
    quality = "quality"
    spam = "spam"
# Common base for the labeling tasks: holds the fields every label task shares.
# Task-specific text fields (`prompt`, `reply`, ...) live on the subclasses.
class AbstractLabelTask(Task):
    message_id: UUID
    valid_labels: list[str]
    mandatory_labels: Optional[list[str]]
    mode: Optional[LabelTaskMode]
    disposition: Optional[LabelTaskDisposition]
class LabelInitialPromptTask(AbstractLabelTask):
    """A task to label an initial prompt."""

    type: Literal["label_initial_prompt"] = "label_initial_prompt"
    prompt: str
class LabelConversationReplyTask(AbstractLabelTask):
    """A task to label a reply to a conversation."""

    # message_id, valid_labels, mandatory_labels and mode are inherited from
    # AbstractLabelTask and therefore not redeclared here.
    type: Literal["label_conversation_reply"] = "label_conversation_reply"
    conversation: Conversation  # the conversation so far
    reply: str
class LabelPrompterReplyTask(LabelConversationReplyTask):