diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index c53f6bae..85ab74ca 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -24,7 +24,8 @@ def get_message_by_frontend_id( message = pr.fetch_message_by_frontend_message_id(message_id) if not isinstance(message.payload.payload, MessagePayload): - raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) + # Unexpected message payload + raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE) return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index f5df2b24..35048331 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -18,7 +18,7 @@ router = APIRouter() def query_frontend_user_messages( username: str, api_client_id: UUID = None, - max_count: int = Query(10, gt=0, le=25), + max_count: int = Query(10, gt=0, le=1000), start_date: datetime.datetime = None, end_date: datetime.datetime = None, only_roots: bool = False, diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index a812ef61..9221afc3 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -21,7 +21,7 @@ router = APIRouter() def query_messages( username: str = None, api_client_id: str = None, - max_count: int = Query(10, gt=0, le=25), + max_count: int = Query(10, gt=0, le=1000), start_date: datetime.datetime = None, end_date: datetime.datetime = None, only_roots: bool = False, @@ -63,7 +63,8 @@ def get_message( pr = PromptRepository(db, api_client, user=None) message = pr.fetch_message(message_id) if not isinstance(message.payload.payload, MessagePayload): - raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) + # Unexptcted message payload + raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE) return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) diff --git a/backend/oasst_backend/api/v1/stats.py b/backend/oasst_backend/api/v1/stats.py index 789e53f9..831d4df2 100644 --- a/backend/oasst_backend/api/v1/stats.py +++ b/backend/oasst_backend/api/v1/stats.py @@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.schemas import protocol from sqlmodel import Session router = APIRouter() @@ -15,4 +14,4 @@ def get_message_stats( api_client: ApiClient = Depends(deps.get_trusted_api_client), ): pr = PromptRepository(db, api_client, None) - return protocol.SystemStats(**pr.get_stats()) + return pr.get_stats() diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 000b0970..0bac4d6a 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -18,7 +18,7 @@ router = APIRouter() def query_user_messages( user_id: UUID, api_client_id: UUID = None, - max_count: int = Query(10, gt=0, le=25), + max_count: int = Query(10, gt=0, le=1000), start_date: datetime.datetime = None, end_date: datetime.datetime = None, only_roots: bool = False, diff --git a/backend/oasst_backend/exceptions.py b/backend/oasst_backend/exceptions.py index 7f88caed..f431b05b 100644 --- a/backend/oasst_backend/exceptions.py +++ b/backend/oasst_backend/exceptions.py @@ -36,6 +36,7 @@ class OasstErrorCode(IntEnum): USER_NOT_SPECIFIED = 2005 NO_MESSAGE_TREE_FOUND = 2006 NO_REPLIES_FOUND = 2007 + INVALID_MESSAGE = 2008 TASK_NOT_FOUND = 2100 TASK_EXPIRED = 2101 TASK_PAYLOAD_TYPE_MISMATCH = 2102 diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 155a6d2a..b95f9d53 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -13,6 +13,7 @@ from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import SystemStats from sqlalchemy import update from sqlmodel import Session, func from starlette.status import HTTP_403_FORBIDDEN @@ -57,6 +58,7 @@ class PromptRepository: return user def validate_frontend_message_id(self, message_id: str) -> None: + # TODO: Should it be replaced with fastapi/pydantic validation? if not isinstance(message_id, str): raise OasstError( f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID @@ -450,6 +452,13 @@ class PromptRepository: return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all() def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None): + """ + Fetch a conversation with multiple possible replies to it. + + This function finds a random message with >1 replies, + forms a conversation from the corresponding message tree root up to this message + and fetches up to max_size possible replies in continuation to this conversation. + """ parent = self.db.query(Message.id).filter(Message.children_count > 1) if message_role: parent = parent.filter(Message.role == message_role) @@ -483,6 +492,9 @@ class PromptRepository: return message def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): + """ + Mark task as done. No further messages will be accepted for this task. + """ self.validate_frontend_message_id(frontend_message_id) task = self.fetch_task_by_frontend_message_id(frontend_message_id) @@ -565,6 +577,11 @@ class PromptRepository: return list(_traverse_subtree(root)) def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]: + """ + Find all descendant messages to this message. + + This function creates a subtree of messages starting from given root message. + """ if isinstance(message, UUID): message = self.fetch_message(message) @@ -640,6 +657,9 @@ class PromptRepository: return messages.all() def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): + """ + Marks deleted messages and all their descendants. + """ if isinstance(messages, (Message, UUID)): messages = [messages] @@ -666,14 +686,19 @@ class PromptRepository: self.db.commit() - def get_stats(self): + def get_stats(self) -> SystemStats: + """ + Get data stats such as number of all messages in the system, + number of deleted and active messages and number of message trees. + """ deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted) nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None)) query = deleted.union_all(nthreads) result = {k: v for k, v in query.all()} - return { - "all": result.get(True, 0) + result.get(False, 0), - "active": result.get(False, 0), - "deleted": result.get(True, 0), - "threads": result.get(None, 0), - } + + return SystemStats( + all=result.get(True, 0) + result.get(False, 0), + active=result.get(False, 0), + deleted=result.get(True, 0), + message_trees=result.get(None, 0), + ) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index a2a12ddc..ed7dc780 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -277,4 +277,4 @@ class SystemStats(BaseModel): all: int = 0 active: int = 0 deleted: int = 0 - threads: int = 0 + message_trees: int = 0