mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
Duplicate message reply filter (#958)
* added changes for user-specific message duplication filter, added error codes, and settings variable as described in Draft PR #926, ran pre-commit * removed debug statements * add missing await to async_managed_tx_method * add 2nd missing await to async_managed_tx_method * added changes for user-specific message duplication filter, added error codes, and settings variable as described in Draft PR #926, ran pre-commit * removed debug statements * assert task user matches prompt_repository user * removed assert statements * moved duplicate_message_filter and message_length check into store_text_reply * removed old checks in tree_manager
This commit is contained in:
@@ -138,6 +138,8 @@ class Settings(BaseSettings):
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
|
||||
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
@@ -9,6 +10,7 @@ import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import FrontendUserId
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import (
|
||||
ApiClient,
|
||||
@@ -30,7 +32,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from oasst_shared.utils import unaware_to_utc, utcnow
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
@@ -188,6 +190,18 @@ class PromptRepository:
|
||||
role = None
|
||||
depth = 0
|
||||
|
||||
# reject messages that consist solely of whitespace (regex ^\s+$)
|
||||
if re.match(r"^\s+$", text):
|
||||
raise OasstError("Message text is empty", OasstErrorCode.TASK_MESSAGE_TEXT_EMPTY)
|
||||
|
||||
# ensure message size is below the predefined limit
|
||||
if len(text) > settings.MESSAGE_SIZE_LIMIT:
|
||||
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
if self.check_users_recent_replies_for_duplicates(text):
|
||||
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
|
||||
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
|
||||
@@ -556,6 +570,30 @@ class PromptRepository:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def check_users_recent_replies_for_duplicates(self, text: str) -> bool:
    """
    Check whether the current user recently submitted a message with identical text.

    Considers every message authored by ``self.user_id`` whose ``created_date`` lies
    within the last ``settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES`` minutes and
    compares each message's text to ``text`` using exact string equality.

    Args:
        text: The reply text that is about to be stored.

    Returns:
        True if at least one message inside the window has exactly the same text,
        False otherwise (including when the user has no messages in the window).
    """
    user_id = self.user_id
    logger.debug(f"Checking for duplicate tasks for user {user_id}")

    # The look-back window is configurable via settings; it is not a fixed 24 hours.
    window_start = utcnow() - timedelta(minutes=settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES)

    # No ORDER BY: only the existence of a match matters, not which message matched.
    recent_messages = (
        self.db.query(Message)
        .filter(Message.user_id == user_id)
        .filter(Message.created_date > window_start)
        .all()
    )

    # NOTE(review): the text comparison stays in Python, as in the original. If
    # Message.text maps directly to a column it could be pushed into the query - confirm.
    return any(msg.text == text for msg in recent_messages)
|
||||
|
||||
def fetch_user_message_trees(
|
||||
self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False
|
||||
) -> list[Message]:
|
||||
|
||||
@@ -520,14 +520,6 @@ class TreeManager:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# ensure message size is below the predefined limit
|
||||
if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT:
|
||||
logger.error(
|
||||
f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}."
|
||||
)
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
# here we store the text reply in the database
|
||||
message = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
|
||||
@@ -38,6 +38,8 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
|
||||
TASK_AVAILABILITY_QUERY_FAILED = 1007
|
||||
TASK_MESSAGE_TOO_LONG = 1008
|
||||
TASK_MESSAGE_DUPLICATED = 1009
|
||||
TASK_MESSAGE_TEXT_EMPTY = 1010
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
|
||||
Reference in New Issue
Block a user