mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
968 add flagged message table and endpoints (#1325)
* Added flagged message table * Added alembic migration and updated imports to match style * Added GET endpoint to query all flagged messages * Updates from linter * Added POST endpoint for processing flagged messages * Added pydantic interface model and fixed limit update bug * fixed session in admin endpoint and added require session refresh for returned update * removed unused import
This commit is contained in:
+41
@@ -0,0 +1,41 @@
|
||||
"""Added new table for flagged messages
|
||||
|
||||
Revision ID: caee1e8ee0bc
|
||||
Revises: 8c8241d1f973
|
||||
Create Date: 2023-02-07 19:22:12.696257
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "caee1e8ee0bc"
|
||||
down_revision = "8c8241d1f973"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"flagged_message",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column("processed", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("message_id"),
|
||||
)
|
||||
op.create_index(op.f("ix_flagged_message_created_date"), "flagged_message", ["created_date"], unique=False)
|
||||
op.create_index(op.f("ix_flagged_message_processed"), "flagged_message", ["processed"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_flagged_message_processed"), table_name="flagged_message")
|
||||
op.drop_index(op.f("ix_flagged_message_created_date"), table_name="flagged_message")
|
||||
op.drop_table("flagged_message")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
@@ -162,3 +163,37 @@ async def purge_user_messages(
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
|
||||
class FlaggedMessageResponse(pydantic.BaseModel):
|
||||
message_id: UUID
|
||||
processed: bool
|
||||
created_date: Optional[datetime]
|
||||
|
||||
|
||||
@router.get("/flagged_messages", response_model=list[FlaggedMessageResponse])
|
||||
async def get_flagged_messages(
|
||||
max_count: Optional[int],
|
||||
session: deps.Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
pr = PromptRepository(session, api_client)
|
||||
flagged_messages = pr.fetch_flagged_messages(max_count=max_count)
|
||||
resp = [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages]
|
||||
return resp
|
||||
|
||||
|
||||
@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse)
|
||||
async def process_flagged_messages(
|
||||
message_id: UUID,
|
||||
session: deps.Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
pr = PromptRepository(session, api_client)
|
||||
flagged_msg = pr.process_flagged_message(message_id=message_id)
|
||||
resp = FlaggedMessageResponse(**flagged_msg.__dict__)
|
||||
return resp
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .api_client import ApiClient
|
||||
from .flagged_message import FlaggedMessage
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
@@ -28,4 +29,5 @@ __all__ = [
|
||||
"JournalIntegration",
|
||||
"MessageEmoji",
|
||||
"TrollStats",
|
||||
"FlaggedMessage",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class FlaggedMessage(SQLModel, table=True):
|
||||
__tablename__ = "flagged_message"
|
||||
|
||||
message_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
|
||||
)
|
||||
)
|
||||
processed: bool = Field(nullable=False, index=True)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
|
||||
)
|
||||
)
|
||||
@@ -7,12 +7,14 @@ from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import FrontendUserId
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import (
|
||||
ApiClient,
|
||||
FlaggedMessage,
|
||||
Message,
|
||||
MessageEmbedding,
|
||||
MessageEmoji,
|
||||
@@ -1092,6 +1094,15 @@ WHERE message.id = cc.id;
|
||||
logger.debug(f"Ignoring add emoji op for user's own message ({emoji=})")
|
||||
return message
|
||||
|
||||
# Add to flagged_message table if the red flag emoji is applied
|
||||
if emoji == protocol_schema.EmojiCode.red_flag:
|
||||
flagged_message = FlaggedMessage(
|
||||
message_id=message_id, processed=False, created_date=datetime.now().astimezone()
|
||||
)
|
||||
insert_stmt = pg.insert(FlaggedMessage).values(**flagged_message.__dict__)
|
||||
upsert_stmt = insert_stmt.on_conflict_do_update(constraint="message_id", set_=flagged_message.__dict__)
|
||||
self.db.execute(upsert_stmt)
|
||||
|
||||
# insert emoji record & increment count
|
||||
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
|
||||
self.db.add(message_emoji)
|
||||
@@ -1127,3 +1138,23 @@ WHERE message.id = cc.id;
|
||||
self.db.add(message)
|
||||
self.db.flush()
|
||||
return message
|
||||
|
||||
def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessage]:
|
||||
qry = self.db.query(FlaggedMessage)
|
||||
if max_count is not None:
|
||||
qry = qry.limit(max_count)
|
||||
|
||||
return qry.all()
|
||||
|
||||
def process_flagged_message(self, message_id: UUID) -> FlaggedMessage:
|
||||
|
||||
message = self.db.query(FlaggedMessage).get(message_id)
|
||||
|
||||
if not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
|
||||
|
||||
message.processed = True
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
Reference in New Issue
Block a user