From 66b7ed2b9beba9a504792e8b7aab33494df78ff6 Mon Sep 17 00:00:00 2001 From: Graeme Harris Date: Tue, 7 Feb 2023 23:39:24 +0200 Subject: [PATCH] 968 add flagged message table and endpoints (#1325) * Added flagged message table * Added alembic migration and updated imports to match style * Added GET endpoint to query all flagged messages * Updates from linter * Added POST endpoint for processing flagged messages * Added pydantic interface model and fixed limit update bug * fixed session in admin endpoint and added require session refresh for returned update * removed unused import --- ...bc_added_new_table_for_flagged_messages.py | 41 +++++++++++++++++++ backend/oasst_backend/api/v1/admin.py | 35 ++++++++++++++++ backend/oasst_backend/models/__init__.py | 2 + .../oasst_backend/models/flagged_message.py | 23 +++++++++++ backend/oasst_backend/prompt_repository.py | 31 ++++++++++++++ 5 files changed, 132 insertions(+) create mode 100644 backend/alembic/versions/2023_02_07_1922-caee1e8ee0bc_added_new_table_for_flagged_messages.py create mode 100644 backend/oasst_backend/models/flagged_message.py diff --git a/backend/alembic/versions/2023_02_07_1922-caee1e8ee0bc_added_new_table_for_flagged_messages.py b/backend/alembic/versions/2023_02_07_1922-caee1e8ee0bc_added_new_table_for_flagged_messages.py new file mode 100644 index 00000000..674313b5 --- /dev/null +++ b/backend/alembic/versions/2023_02_07_1922-caee1e8ee0bc_added_new_table_for_flagged_messages.py @@ -0,0 +1,41 @@ +"""Added new table for flagged messages + +Revision ID: caee1e8ee0bc +Revises: 8c8241d1f973 +Create Date: 2023-02-07 19:22:12.696257 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "caee1e8ee0bc" +down_revision = "8c8241d1f973" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "flagged_message", + sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column( + "created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False + ), + sa.Column("processed", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_index(op.f("ix_flagged_message_created_date"), "flagged_message", ["created_date"], unique=False) + op.create_index(op.f("ix_flagged_message_processed"), "flagged_message", ["processed"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_flagged_message_processed"), table_name="flagged_message") + op.drop_index(op.f("ix_flagged_message_created_date"), table_name="flagged_message") + op.drop_table("flagged_message") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py index 584c5fb7..4974782d 100644 --- a/backend/oasst_backend/api/v1/admin.py +++ b/backend/oasst_backend/api/v1/admin.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from uuid import UUID import pydantic @@ -162,3 +163,37 @@ async def purge_user_messages( logger.info(f"{before=}; {after=}") return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed) + + +class FlaggedMessageResponse(pydantic.BaseModel): + message_id: UUID + processed: bool + created_date: Optional[datetime] + + +@router.get("/flagged_messages", response_model=list[FlaggedMessageResponse]) +async def get_flagged_messages( + max_count: Optional[int], + session: deps.Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> str: + assert api_client.trusted + + pr = PromptRepository(session, api_client) + flagged_messages = pr.fetch_flagged_messages(max_count=max_count) + resp = [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages] + return resp + + +@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse) +async def process_flagged_messages( + message_id: UUID, + session: deps.Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> str: + assert api_client.trusted + + pr = PromptRepository(session, api_client) + flagged_msg = pr.process_flagged_message(message_id=message_id) + resp = FlaggedMessageResponse(**flagged_msg.__dict__) + return resp diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 65594dde..62d38e0a 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -1,4 +1,5 @@ from .api_client import ApiClient +from .flagged_message import FlaggedMessage from .journal import Journal, JournalIntegration from .message import Message from .message_embedding import MessageEmbedding @@ -28,4 +29,5 @@ __all__ = [ "JournalIntegration", "MessageEmoji", "TrollStats", + "FlaggedMessage", ] diff --git a/backend/oasst_backend/models/flagged_message.py b/backend/oasst_backend/models/flagged_message.py new file mode 100644 index 00000000..121c7034 --- /dev/null +++ b/backend/oasst_backend/models/flagged_message.py @@ -0,0 +1,23 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, SQLModel + + +class FlaggedMessage(SQLModel, table=True): + __tablename__ = "flagged_message" + + message_id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True + ) + ) + processed: bool = Field(nullable=False, index=True) + created_date: Optional[datetime] = Field( + sa_column=sa.Column( + sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True + ) + ) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index c5fb0208..958eaeba 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -7,12 +7,14 @@ from typing import Optional from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload +import sqlalchemy.dialects.postgresql as pg from loguru import logger from oasst_backend.api.deps import FrontendUserId from oasst_backend.config import settings from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ( ApiClient, + FlaggedMessage, Message, MessageEmbedding, MessageEmoji, @@ -1092,6 +1094,15 @@ WHERE message.id = cc.id; logger.debug(f"Ignoring add emoji op for user's own message ({emoji=})") return message + # Add to flagged_message table if the red flag emoji is applied + if emoji == protocol_schema.EmojiCode.red_flag: + flagged_message = FlaggedMessage( + message_id=message_id, processed=False, created_date=datetime.now().astimezone() + ) + insert_stmt = pg.insert(FlaggedMessage).values(**flagged_message.__dict__) + upsert_stmt = insert_stmt.on_conflict_do_update(constraint="message_id", set_=flagged_message.__dict__) + self.db.execute(upsert_stmt) + # insert emoji record & increment count message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji) self.db.add(message_emoji) @@ -1127,3 +1138,23 @@ WHERE message.id = cc.id; self.db.add(message) self.db.flush() return message + + def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessage]: + qry = self.db.query(FlaggedMessage) + if max_count is not None: + qry = qry.limit(max_count) + + return qry.all() + + def process_flagged_message(self, message_id: UUID) -> FlaggedMessage: + + message = self.db.query(FlaggedMessage).get(message_id) + + if not message: + raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND) + + message.processed = True + self.db.commit() + self.db.refresh(message) + + return message