diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 157845d7..92316037 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -138,6 +138,8 @@ class Settings(BaseSettings): DEBUG_SKIP_TOXICITY_CALCULATION: bool = False DEBUG_DATABASE_ECHO: bool = False + DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120 + HUGGING_FACE_API_KEY: str = "" ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8e2c746d..01f63d26 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -1,6 +1,7 @@ import random +import re from collections import defaultdict -from datetime import datetime +from datetime import datetime, timedelta from http import HTTPStatus from typing import Optional from uuid import UUID, uuid4 @@ -9,6 +10,7 @@ import oasst_backend.models.db_payload as db_payload import sqlalchemy as sa from loguru import logger from oasst_backend.api.deps import FrontendUserId +from oasst_backend.config import settings from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ( ApiClient, @@ -30,7 +32,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats -from oasst_shared.utils import unaware_to_utc +from oasst_shared.utils import unaware_to_utc, utcnow from sqlalchemy.orm import Query from sqlalchemy.orm.attributes import flag_modified from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update @@ -188,6 +190,18 @@ class PromptRepository: role = None depth = 0 + # reject empty or whitespace-only text (^\s*$ also matches the empty string) + if re.match(r"^\s*$", text): + raise OasstError("Message text is empty", 
OasstErrorCode.TASK_MESSAGE_TEXT_EMPTY) + + # ensure message size is below the predefined limit + if len(text) > settings.MESSAGE_SIZE_LIMIT: + logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.") + raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG) + + if self.check_users_recent_replies_for_duplicates(text): + raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED) + if task.parent_message_id: parent_message = self.fetch_message(task.parent_message_id) @@ -556,6 +570,30 @@ class PromptRepository: qry = qry.filter(not_(Message.deleted)) return self._add_user_emojis_all(qry) + def check_users_recent_replies_for_duplicates(self, text: str) -> bool: + """ + Checks if the user has recently replied with the same text within a given time period. + """ + + user_id = self.user_id + logger.debug(f"Checking for duplicate tasks for user {user_id}") + # messages within the last settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES minutes + messages = ( + self.db.query(Message) + .filter(Message.user_id == user_id) + .order_by(Message.created_date.desc()) + .filter( + Message.created_date > utcnow() - timedelta(minutes=settings.DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES) + ) + .all() + ) + if not messages: + return False + for msg in messages: + if msg.text == text: + return True + return False + def fetch_user_message_trees( self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False ) -> list[Message]: diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 909bb84b..52aff24e 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -520,14 +520,6 @@ class TreeManager: logger.info( f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}." 
) - - # ensure message size is below the predefined limit - if len(interaction.text) > settings.MESSAGE_SIZE_LIMIT: - logger.error( - f"Message size {len(interaction.text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}." - ) - raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG) - # here we store the text reply in the database message = pr.store_text_reply( text=interaction.text, diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 9764062e..e7f5dcd5 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -38,6 +38,8 @@ class OasstErrorCode(IntEnum): TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006 TASK_AVAILABILITY_QUERY_FAILED = 1007 TASK_MESSAGE_TOO_LONG = 1008 + TASK_MESSAGE_DUPLICATED = 1009 + TASK_MESSAGE_TEXT_EMPTY = 1010 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000