Duplicate message reply filter (#958)

* added changes for user-specific message duplication filter, added error codes, and settings variable as described in Draft PR #926, ran precommit

* removed debug statements

* add missing await to async_managed_tx_method

* add 2nd missing await to async_managed_tx_method

* added changes for user-specific message duplication filter, added error codes, and settings variable as described in Draft PR #926, ran precommit

* removed debug statements

* assert task user matches prompt_repository user

* removed assert statements

* moved duplicate_message_filter and message_length check into store_text_reply

* removed old checks in tree_manager
This commit is contained in:
dhug
2023-01-27 16:52:19 -05:00
committed by GitHub
parent 3a32a10b23
commit 49b5999ce6
4 changed files with 44 additions and 10 deletions
+2
View File
@@ -138,6 +138,8 @@ class Settings(BaseSettings):
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
DEBUG_DATABASE_ECHO: bool = False
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
HUGGING_FACE_API_KEY: str = ""
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
+40 -2
View File
@@ -1,6 +1,7 @@
import random
import re
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timedelta
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4
@@ -9,6 +10,7 @@ import oasst_backend.models.db_payload as db_payload
import sqlalchemy as sa
from loguru import logger
from oasst_backend.api.deps import FrontendUserId
from oasst_backend.config import settings
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import (
ApiClient,
@@ -30,7 +32,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import unaware_to_utc
from oasst_shared.utils import unaware_to_utc, utcnow
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
@@ -188,6 +190,18 @@ class PromptRepository:
role = None
depth = 0
# reject whitespaces match with ^\s+$
if re.match(r"^\s+$", text):
raise OasstError("Message text is empty", OasstErrorCode.TASK_MESSAGE_TEXT_EMPTY)
# ensure message size is below the predefined limit
if len(text) > settings.MESSAGE_SIZE_LIMIT:
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
if self.check_users_recent_replies_for_duplicates(text):
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
if task.parent_message_id:
parent_message = self.fetch_message(task.parent_message_id)
@@ -556,6 +570,30 @@ class PromptRepository:
qry = qry.filter(not_(Message.deleted))
return self._add_user_emojis_all(qry)
def check_users_recent_replies_for_duplicates(self, text: str) -> bool:
    """
    Check whether the current user recently submitted a message with identical text.

    Only messages belonging to ``self.user_id`` and created within the last
    ``settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES`` minutes are considered.

    Args:
        text: Candidate message text; compared for exact equality.

    Returns:
        True if at least one of the user's recent messages has exactly this text,
        False otherwise (including when the user has no recent messages).
    """
    user_id = self.user_id
    logger.debug(f"Checking for duplicate tasks for user {user_id}")

    # Lower bound of the look-back window; its length is configurable via settings.
    window_start = utcnow() - timedelta(minutes=settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES)

    # No ordering is requested: only membership matters, not which message matched.
    recent_messages = (
        self.db.query(Message)
        .filter(Message.user_id == user_id)
        .filter(Message.created_date > window_start)
        .all()
    )

    # The text comparison happens in Python on the loaded rows.
    # NOTE(review): if Message.text maps to a plain column this could be pushed
    # into the query instead -- confirm against the model definition.
    return any(msg.text == text for msg in recent_messages)
def fetch_user_message_trees(
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
) -> list[Message]:
-8
View File
@@ -520,14 +520,6 @@ class TreeManager:
logger.info(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# ensure message size is below the predefined limit
if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT:
logger.error(
f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}."
)
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
@@ -38,6 +38,8 @@ class OasstErrorCode(IntEnum):
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
TASK_AVAILABILITY_QUERY_FAILED = 1007
TASK_MESSAGE_TOO_LONG = 1008
TASK_MESSAGE_DUPLICATED = 1009
TASK_MESSAGE_TEXT_EMPTY = 1010
# 2000-3000: prompt_repository
INVALID_FRONTEND_MESSAGE_ID = 2000