Add /messages/{message_id}/emoji endpoint to toggle, add, remove message emojis (#925)

* add endpoint to set message emojis

* make refresh result optional in db utils
This commit is contained in:
Andreas Köpf
2023-01-25 09:31:20 +01:00
committed by GitHub
parent 4146930eb9
commit 558b207013
10 changed files with 187 additions and 0 deletions
@@ -0,0 +1,44 @@
"""add message_emoji
Revision ID: 40ed93df0ed5
Revises: 8ba17b5f467a
Create Date: 2023-01-24 22:56:28.229408
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "40ed93df0ed5"
down_revision = "8ba17b5f467a"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_emoji",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
),
sa.Column("emoji", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("message_id", "user_id", "emoji"),
)
op.create_index("ix_message_emoji__user_id__message_id", "message_emoji", ["user_id", "message_id"], unique=False)
op.add_column("message", sa.Column("emojis", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "emojis")
op.drop_index("ix_message_emoji__user_id__message_id", table_name="message_emoji")
op.drop_table("message_emoji")
# ### end Alembic commands ###
+20
View File
@@ -7,6 +7,7 @@ from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
@@ -229,3 +230,22 @@ def mark_message_deleted(
):
pr = PromptRepository(db, api_client)
pr.mark_messages_deleted(message_id)
@router.post("/{message_id}/emoji", response_model=protocol.Message)
def post_message_emoji(
*,
message_id: UUID,
request: protocol.MessageEmojiRequest,
api_client: ApiClient = Depends(deps.get_api_client),
) -> protocol.Message:
"""
Toggle, add or remove message emoji.
"""
@managed_tx_function(CommitMode.COMMIT)
def emoji_tx(session: deps.Session):
pr = PromptRepository(session, api_client, client_user=request.user)
return pr.handle_message_emoji(message_id, request.op, request.emoji)
return utils.prepare_message(emoji_tx())
+1
View File
@@ -14,6 +14,7 @@ def prepare_message(m: Message) -> protocol.Message:
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
emojis=m.emojis,
)
+2
View File
@@ -2,6 +2,7 @@ from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_emoji import MessageEmoji
from .message_reaction import MessageReaction
from .message_toxicity import MessageToxicity
from .message_tree_state import MessageTreeState
@@ -24,4 +25,5 @@ __all__ = [
"TextLabels",
"Journal",
"JournalIntegration",
"MessageEmoji",
]
+2
View File
@@ -49,6 +49,8 @@ class Message(SQLModel, table=True):
rank: Optional[int] = Field(nullable=True)
emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
@@ -0,0 +1,27 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class MessageEmoji(SQLModel, table=True):
__tablename__ = "message_emoji"
__table_args__ = (Index("ix_message_emoji__user_id__message_id", "user_id", "message_id", unique=False),)
message_id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
)
)
user_id: UUID = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, primary_key=True
)
)
emoji: str = Field(nullable=False, max_length=128, primary_key=True)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
@@ -13,6 +13,7 @@ from oasst_backend.models import (
ApiClient,
Message,
MessageEmbedding,
MessageEmoji,
MessageReaction,
MessageToxicity,
MessageTreeState,
@@ -29,6 +30,7 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import unaware_to_utc
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, and_, func, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -843,3 +845,62 @@ WHERE message.id = cc.id;
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
self.ensure_user_is_enabled()
message = self.fetch_message(message_id)
# check if emoji exists
existing_emoji = (
self.db.query(MessageEmoji)
.filter(
MessageEmoji.message_id == message_id, MessageEmoji.user_id == self.user_id, MessageEmoji.emoji == emoji
)
.one_or_none()
)
if existing_emoji:
if op == protocol_schema.EmojiOp.add:
logger.info(f"Emoji record already exists {message_id=}, {emoji=}, {self.user_id=}")
return message
elif op == protocol_schema.EmojiOp.togggle:
op = protocol_schema.EmojiOp.remove
if existing_emoji is None:
if op == protocol_schema.EmojiOp.remove:
logger.info(f"Emoji record not found {message_id=}, {emoji=}, {self.user_id=}")
return message
elif op == protocol_schema.EmojiOp.togggle:
op = protocol_schema.EmojiOp.add
if op == protocol_schema.EmojiOp.add:
# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
emoji_counts = message.emojis
if not emoji_counts:
message.emojis = {emoji.value: 1}
else:
count = emoji_counts.get(emoji.value) or 0
emoji_counts[emoji.value] = count + 1
elif op == protocol_schema.EmojiOp.remove:
# remove emoji record and & decrement count
message = self.fetch_message(message_id)
self.db.delete(existing_emoji)
emoji_counts = message.emojis
count = emoji_counts.get(emoji.value)
if count is not None:
if count == 1:
del emoji_counts[emoji.value]
else:
emoji_counts[emoji.value] = count - 1
flag_modified(message, "emojis")
self.db.add(message)
else:
raise OasstError("Emoji op not supported", OasstErrorCode.EMOJI_OP_UNSUPPORTED)
flag_modified(message, "emojis")
self.db.add(message)
self.db.flush()
return message
@@ -107,6 +107,7 @@ def managed_tx_function(
auto_commit: CommitMode = CommitMode.COMMIT,
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
session_factory: Callable[..., Session] = default_session_factor,
refresh_result: bool = True,
):
"""Passes Session object as first argument to wrapped function."""
@@ -124,6 +125,8 @@ def managed_tx_function(
session.flush()
elif auto_commit == CommitMode.ROLLBACK:
session.rollback()
if refresh_result and isinstance(result, SQLModel):
session.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
@@ -76,6 +76,8 @@ class OasstErrorCode(IntEnum):
USER_DISABLED = 4001
USER_NOT_FOUND = 4002
EMOJI_OP_UNSUPPORTED = 5000
class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
@@ -80,6 +80,7 @@ class Conversation(BaseModel):
class Message(ConversationMessage):
parent_id: Optional[UUID] = None
created_date: Optional[datetime] = None
emojis: Optional[dict] = None
class MessagePage(PageResult):
@@ -432,3 +433,27 @@ class OasstErrorResponse(BaseModel):
error_code: OasstErrorCode
message: str
class EmojiCode(str, enum.Enum):
thumbs_up = "+1" # 👍
thumbs_down = "-1" # 👎
red_flag = "red_flag" # 🚩
hundred = "100" # 💯
rofl = "rofl" # 🤣"
heart_eyes = "heart_eyes" # 😍
disappointed = "disappointed" # 😞
poop = "poop" # 💩
skull = "skull" # 💀
class EmojiOp(str, enum.Enum):
togggle = "toggle"
add = "add"
remove = "remove"
class MessageEmojiRequest(BaseModel):
user: User
op: EmojiOp = EmojiOp.togggle
emoji: EmojiCode