mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add /messages/{message_id}/emoji endpoint to toggle, add, remove message emojis (#925)
* add endpoint to set message emojis * make refresh result optional in db utils
This commit is contained in:
@@ -0,0 +1,44 @@
|
||||
"""add message_emoji
|
||||
|
||||
Revision ID: 40ed93df0ed5
|
||||
Revises: 8ba17b5f467a
|
||||
Create Date: 2023-01-24 22:56:28.229408
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "40ed93df0ed5"
|
||||
down_revision = "8ba17b5f467a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"message_emoji",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column("emoji", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("message_id", "user_id", "emoji"),
|
||||
)
|
||||
op.create_index("ix_message_emoji__user_id__message_id", "message_emoji", ["user_id", "message_id"], unique=False)
|
||||
op.add_column("message", sa.Column("emojis", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message", "emojis")
|
||||
op.drop_index("ix_message_emoji__user_id__message_id", table_name="message_emoji")
|
||||
op.drop_table("message_emoji")
|
||||
# ### end Alembic commands ###
|
||||
@@ -7,6 +7,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
@@ -229,3 +230,22 @@ def mark_message_deleted(
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
|
||||
|
||||
@router.post("/{message_id}/emoji", response_model=protocol.Message)
|
||||
def post_message_emoji(
|
||||
*,
|
||||
message_id: UUID,
|
||||
request: protocol.MessageEmojiRequest,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
) -> protocol.Message:
|
||||
"""
|
||||
Toggle, add or remove message emoji.
|
||||
"""
|
||||
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def emoji_tx(session: deps.Session):
|
||||
pr = PromptRepository(session, api_client, client_user=request.user)
|
||||
return pr.handle_message_emoji(message_id, request.op, request.emoji)
|
||||
|
||||
return utils.prepare_message(emoji_tx())
|
||||
|
||||
@@ -14,6 +14,7 @@ def prepare_message(m: Message) -> protocol.Message:
|
||||
lang=m.lang,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
emojis=m.emojis,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
from .message_emoji import MessageEmoji
|
||||
from .message_reaction import MessageReaction
|
||||
from .message_toxicity import MessageToxicity
|
||||
from .message_tree_state import MessageTreeState
|
||||
@@ -24,4 +25,5 @@ __all__ = [
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
"MessageEmoji",
|
||||
]
|
||||
|
||||
@@ -49,6 +49,8 @@ class Message(SQLModel, table=True):
|
||||
|
||||
rank: Optional[int] = Field(nullable=True)
|
||||
|
||||
emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class MessageEmoji(SQLModel, table=True):
|
||||
__tablename__ = "message_emoji"
|
||||
__table_args__ = (Index("ix_message_emoji__user_id__message_id", "user_id", "message_id", unique=False),)
|
||||
|
||||
message_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
|
||||
)
|
||||
)
|
||||
user_id: UUID = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, primary_key=True
|
||||
)
|
||||
)
|
||||
emoji: str = Field(nullable=False, max_length=128, primary_key=True)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
@@ -13,6 +13,7 @@ from oasst_backend.models import (
|
||||
ApiClient,
|
||||
Message,
|
||||
MessageEmbedding,
|
||||
MessageEmoji,
|
||||
MessageReaction,
|
||||
MessageToxicity,
|
||||
MessageTreeState,
|
||||
@@ -29,6 +30,7 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import Session, and_, func, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
@@ -843,3 +845,62 @@ WHERE message.id = cc.id;
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
message = self.fetch_message(message_id)
|
||||
|
||||
# check if emoji exists
|
||||
existing_emoji = (
|
||||
self.db.query(MessageEmoji)
|
||||
.filter(
|
||||
MessageEmoji.message_id == message_id, MessageEmoji.user_id == self.user_id, MessageEmoji.emoji == emoji
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if existing_emoji:
|
||||
if op == protocol_schema.EmojiOp.add:
|
||||
logger.info(f"Emoji record already exists {message_id=}, {emoji=}, {self.user_id=}")
|
||||
return message
|
||||
elif op == protocol_schema.EmojiOp.togggle:
|
||||
op = protocol_schema.EmojiOp.remove
|
||||
|
||||
if existing_emoji is None:
|
||||
if op == protocol_schema.EmojiOp.remove:
|
||||
logger.info(f"Emoji record not found {message_id=}, {emoji=}, {self.user_id=}")
|
||||
return message
|
||||
elif op == protocol_schema.EmojiOp.togggle:
|
||||
op = protocol_schema.EmojiOp.add
|
||||
|
||||
if op == protocol_schema.EmojiOp.add:
|
||||
# insert emoji record & increment count
|
||||
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
|
||||
self.db.add(message_emoji)
|
||||
emoji_counts = message.emojis
|
||||
if not emoji_counts:
|
||||
message.emojis = {emoji.value: 1}
|
||||
else:
|
||||
count = emoji_counts.get(emoji.value) or 0
|
||||
emoji_counts[emoji.value] = count + 1
|
||||
elif op == protocol_schema.EmojiOp.remove:
|
||||
# remove emoji record and & decrement count
|
||||
message = self.fetch_message(message_id)
|
||||
self.db.delete(existing_emoji)
|
||||
emoji_counts = message.emojis
|
||||
count = emoji_counts.get(emoji.value)
|
||||
if count is not None:
|
||||
if count == 1:
|
||||
del emoji_counts[emoji.value]
|
||||
else:
|
||||
emoji_counts[emoji.value] = count - 1
|
||||
flag_modified(message, "emojis")
|
||||
self.db.add(message)
|
||||
else:
|
||||
raise OasstError("Emoji op not supported", OasstErrorCode.EMOJI_OP_UNSUPPORTED)
|
||||
|
||||
flag_modified(message, "emojis")
|
||||
self.db.add(message)
|
||||
self.db.flush()
|
||||
return message
|
||||
|
||||
@@ -107,6 +107,7 @@ def managed_tx_function(
|
||||
auto_commit: CommitMode = CommitMode.COMMIT,
|
||||
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
|
||||
session_factory: Callable[..., Session] = default_session_factor,
|
||||
refresh_result: bool = True,
|
||||
):
|
||||
"""Passes Session object as first argument to wrapped function."""
|
||||
|
||||
@@ -124,6 +125,8 @@ def managed_tx_function(
|
||||
session.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
session.rollback()
|
||||
if refresh_result and isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
|
||||
@@ -76,6 +76,8 @@ class OasstErrorCode(IntEnum):
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
|
||||
EMOJI_OP_UNSUPPORTED = 5000
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
@@ -80,6 +80,7 @@ class Conversation(BaseModel):
|
||||
class Message(ConversationMessage):
|
||||
parent_id: Optional[UUID] = None
|
||||
created_date: Optional[datetime] = None
|
||||
emojis: Optional[dict] = None
|
||||
|
||||
|
||||
class MessagePage(PageResult):
|
||||
@@ -432,3 +433,27 @@ class OasstErrorResponse(BaseModel):
|
||||
|
||||
error_code: OasstErrorCode
|
||||
message: str
|
||||
|
||||
|
||||
class EmojiCode(str, enum.Enum):
|
||||
thumbs_up = "+1" # 👍
|
||||
thumbs_down = "-1" # 👎
|
||||
red_flag = "red_flag" # 🚩
|
||||
hundred = "100" # 💯
|
||||
rofl = "rofl" # 🤣"
|
||||
heart_eyes = "heart_eyes" # 😍
|
||||
disappointed = "disappointed" # 😞
|
||||
poop = "poop" # 💩
|
||||
skull = "skull" # 💀
|
||||
|
||||
|
||||
class EmojiOp(str, enum.Enum):
|
||||
togggle = "toggle"
|
||||
add = "add"
|
||||
remove = "remove"
|
||||
|
||||
|
||||
class MessageEmojiRequest(BaseModel):
|
||||
user: User
|
||||
op: EmojiOp = EmojiOp.togggle
|
||||
emoji: EmojiCode
|
||||
|
||||
Reference in New Issue
Block a user