From 558b2070134ee2b54e2364d55f6a4ca2f7e2ce03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 25 Jan 2023 09:31:20 +0100 Subject: [PATCH] Add /messages/{message_id}/emoji endpoint to toggle, add, remove message emojis (#925) * add endpoint to set message emojis * make refresh result optional in db utils --- ..._24_2256-40ed93df0ed5_add_message_emoji.py | 44 +++++++++++++ backend/oasst_backend/api/v1/messages.py | 20 ++++++ backend/oasst_backend/api/v1/utils.py | 1 + backend/oasst_backend/models/__init__.py | 2 + backend/oasst_backend/models/message.py | 2 + backend/oasst_backend/models/message_emoji.py | 27 ++++++++ backend/oasst_backend/prompt_repository.py | 61 +++++++++++++++++++ backend/oasst_backend/utils/database_utils.py | 3 + .../exceptions/oasst_api_error.py | 2 + oasst-shared/oasst_shared/schemas/protocol.py | 25 ++++++++ 10 files changed, 187 insertions(+) create mode 100644 backend/alembic/versions/2023_01_24_2256-40ed93df0ed5_add_message_emoji.py create mode 100644 backend/oasst_backend/models/message_emoji.py diff --git a/backend/alembic/versions/2023_01_24_2256-40ed93df0ed5_add_message_emoji.py b/backend/alembic/versions/2023_01_24_2256-40ed93df0ed5_add_message_emoji.py new file mode 100644 index 00000000..17368c0d --- /dev/null +++ b/backend/alembic/versions/2023_01_24_2256-40ed93df0ed5_add_message_emoji.py @@ -0,0 +1,44 @@ +"""add message_emoji + +Revision ID: 40ed93df0ed5 +Revises: 8ba17b5f467a +Create Date: 2023-01-24 22:56:28.229408 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "40ed93df0ed5" +down_revision = "8ba17b5f467a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "message_emoji", + sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column( + "created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False + ), + sa.Column("emoji", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("message_id", "user_id", "emoji"), + ) + op.create_index("ix_message_emoji__user_id__message_id", "message_emoji", ["user_id", "message_id"], unique=False) + op.add_column("message", sa.Column("emojis", postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message", "emojis") + op.drop_index("ix_message_emoji__user_id__message_id", table_name="message_emoji") + op.drop_table("message_emoji") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 06dd3fe1..af3ae42d 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -7,6 +7,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol from sqlmodel import Session @@ -229,3 +230,22 @@ def mark_message_deleted( ): pr = PromptRepository(db, api_client) pr.mark_messages_deleted(message_id) + + +@router.post("/{message_id}/emoji", response_model=protocol.Message) +def post_message_emoji( + *, + message_id: UUID, + request: protocol.MessageEmojiRequest, + api_client: ApiClient = Depends(deps.get_api_client), +) -> protocol.Message: + """ + Toggle, add or remove message emoji. + """ + + @managed_tx_function(CommitMode.COMMIT) + def emoji_tx(session: deps.Session): + pr = PromptRepository(session, api_client, client_user=request.user) + return pr.handle_message_emoji(message_id, request.op, request.emoji) + + return utils.prepare_message(emoji_tx()) diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 99161e32..8b0f378f 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -14,6 +14,7 @@ def prepare_message(m: Message) -> protocol.Message: lang=m.lang, is_assistant=(m.role == "assistant"), created_date=m.created_date, + emojis=m.emojis, ) diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 2b30b475..420c0ccd 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -2,6 +2,7 @@ from .api_client import ApiClient from .journal import Journal, JournalIntegration from .message import Message from .message_embedding import MessageEmbedding +from .message_emoji import MessageEmoji from .message_reaction import MessageReaction from .message_toxicity import MessageToxicity from .message_tree_state import MessageTreeState @@ -24,4 +25,5 @@ __all__ = [ "TextLabels", "Journal", "JournalIntegration", + "MessageEmoji", ] diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index d0b1d869..da0c06c3 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -49,6 +49,8 @@ class Message(SQLModel, table=True): rank: Optional[int] = Field(nullable=True) + emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False) + def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/backend/oasst_backend/models/message_emoji.py b/backend/oasst_backend/models/message_emoji.py new file mode 100644 index 00000000..9e6e92fb --- /dev/null +++ b/backend/oasst_backend/models/message_emoji.py @@ -0,0 +1,27 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, Index, SQLModel + + +class MessageEmoji(SQLModel, table=True): + __tablename__ = "message_emoji" + __table_args__ = (Index("ix_message_emoji__user_id__message_id", "user_id", "message_id", unique=False),) + + message_id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True + ) + ) + user_id: UUID = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, primary_key=True + ) + ) + emoji: str = Field(nullable=False, max_length=128, primary_key=True) + created_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()) + ) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index f9ba9ad6..bbc8abe2 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -13,6 +13,7 @@ from oasst_backend.models import ( ApiClient, Message, MessageEmbedding, + MessageEmoji, MessageReaction, MessageToxicity, MessageTreeState, @@ -29,6 +30,7 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats from oasst_shared.utils import unaware_to_utc +from sqlalchemy.orm.attributes import flag_modified from sqlmodel import Session, and_, func, not_, or_, text, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -843,3 +845,62 @@ WHERE message.id = cc.id; deleted=result.get(True, 0), message_trees=result.get(None, 0), ) + + def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message: + self.ensure_user_is_enabled() + + message = self.fetch_message(message_id) + + # check if emoji exists + existing_emoji = ( + self.db.query(MessageEmoji) + .filter( + MessageEmoji.message_id == message_id, MessageEmoji.user_id == self.user_id, MessageEmoji.emoji == emoji + ) + .one_or_none() + ) + + if existing_emoji: + if op == protocol_schema.EmojiOp.add: + logger.info(f"Emoji record already exists {message_id=}, {emoji=}, {self.user_id=}") + return message + elif op == protocol_schema.EmojiOp.togggle: + op = protocol_schema.EmojiOp.remove + + if existing_emoji is None: + if op == protocol_schema.EmojiOp.remove: + logger.info(f"Emoji record not found {message_id=}, {emoji=}, {self.user_id=}") + return message + elif op == protocol_schema.EmojiOp.togggle: + op = protocol_schema.EmojiOp.add + + if op == protocol_schema.EmojiOp.add: + # insert emoji record & increment count + message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji) + self.db.add(message_emoji) + emoji_counts = message.emojis + if not emoji_counts: + message.emojis = {emoji.value: 1} + else: + count = emoji_counts.get(emoji.value) or 0 + emoji_counts[emoji.value] = count + 1 + elif op == protocol_schema.EmojiOp.remove: + # remove emoji record and & decrement count + message = self.fetch_message(message_id) + self.db.delete(existing_emoji) + emoji_counts = message.emojis + count = emoji_counts.get(emoji.value) + if count is not None: + if count == 1: + del emoji_counts[emoji.value] + else: + emoji_counts[emoji.value] = count - 1 + flag_modified(message, "emojis") + self.db.add(message) + else: + raise OasstError("Emoji op not supported", OasstErrorCode.EMOJI_OP_UNSUPPORTED) + + flag_modified(message, "emojis") + self.db.add(message) + self.db.flush() + return message diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 34113cfc..fb8bf6c5 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -107,6 +107,7 @@ def managed_tx_function( auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT, session_factory: Callable[..., Session] = default_session_factor, + refresh_result: bool = True, ): """Passes Session object as first argument to wrapped function.""" @@ -124,6 +125,8 @@ def managed_tx_function( session.flush() elif auto_commit == CommitMode.ROLLBACK: session.rollback() + if refresh_result and isinstance(result, SQLModel): + session.refresh(result) return result except OperationalError: logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 0a548ebb..9764062e 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -76,6 +76,8 @@ class OasstErrorCode(IntEnum): USER_DISABLED = 4001 USER_NOT_FOUND = 4002 + EMOJI_OP_UNSUPPORTED = 5000 + class OasstError(Exception): """Base class for Open-Assistant exceptions.""" diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 20bbdf9b..bb54b502 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -80,6 +80,7 @@ class Conversation(BaseModel): class Message(ConversationMessage): parent_id: Optional[UUID] = None created_date: Optional[datetime] = None + emojis: Optional[dict] = None class MessagePage(PageResult): @@ -432,3 +433,27 @@ class OasstErrorResponse(BaseModel): error_code: OasstErrorCode message: str + + +class EmojiCode(str, enum.Enum): + thumbs_up = "+1" # 👍 + thumbs_down = "-1" # 👎 + red_flag = "red_flag" # 🚩 + hundred = "100" # 💯 + rofl = "rofl" # 🤣" + heart_eyes = "heart_eyes" # 😍 + disappointed = "disappointed" # 😞 + poop = "poop" # 💩 + skull = "skull" # 💀 + + +class EmojiOp(str, enum.Enum): + togggle = "toggle" + add = "add" + remove = "remove" + + +class MessageEmojiRequest(BaseModel): + user: User + op: EmojiOp = EmojiOp.togggle + emoji: EmojiCode