This commit is contained in:
Igor Miagkov
2022-12-31 05:14:21 +04:00
committed by Andreas Köpf
parent 6d98ba1f75
commit ef3a35ff9c
8 changed files with 42 additions and 15 deletions
@@ -24,7 +24,8 @@ def get_message_by_frontend_id(
message = pr.fetch_message_by_frontend_message_id(message_id)
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@@ -18,7 +18,7 @@ router = APIRouter()
def query_frontend_user_messages(
username: str,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=25),
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
+3 -2
View File
@@ -21,7 +21,7 @@ router = APIRouter()
def query_messages(
username: str = None,
api_client_id: str = None,
max_count: int = Query(10, gt=0, le=25),
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
@@ -63,7 +63,8 @@ def get_message(
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
# Unexptcted message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
+1 -2
View File
@@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@@ -15,4 +14,4 @@ def get_message_stats(
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return protocol.SystemStats(**pr.get_stats())
return pr.get_stats()
+1 -1
View File
@@ -18,7 +18,7 @@ router = APIRouter()
def query_user_messages(
user_id: UUID,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=25),
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
+1
View File
@@ -36,6 +36,7 @@ class OasstErrorCode(IntEnum):
USER_NOT_SPECIFIED = 2005
NO_MESSAGE_TREE_FOUND = 2006
NO_REPLIES_FOUND = 2007
INVALID_MESSAGE = 2008
TASK_NOT_FOUND = 2100
TASK_EXPIRED = 2101
TASK_PAYLOAD_TYPE_MISMATCH = 2102
+32 -7
View File
@@ -13,6 +13,7 @@ from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN
@@ -57,6 +58,7 @@ class PromptRepository:
return user
def validate_frontend_message_id(self, message_id: str) -> None:
# TODO: Should it be replaced with fastapi/pydantic validation?
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
@@ -450,6 +452,13 @@ class PromptRepository:
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
"""
Fetch a conversation with multiple possible replies to it.
This function finds a random message with >1 replies,
forms a conversation from the corresponding message tree root up to this message
and fetches up to max_size possible replies in continuation to this conversation.
"""
parent = self.db.query(Message.id).filter(Message.children_count > 1)
if message_role:
parent = parent.filter(Message.role == message_role)
@@ -483,6 +492,9 @@ class PromptRepository:
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
self.validate_frontend_message_id(frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
@@ -565,6 +577,11 @@ class PromptRepository:
return list(_traverse_subtree(root))
def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]:
"""
Find all descendant messages to this message.
This function creates a subtree of messages starting from given root message.
"""
if isinstance(message, UUID):
message = self.fetch_message(message)
@@ -640,6 +657,9 @@ class PromptRepository:
return messages.all()
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
"""
Marks deleted messages and all their descendants.
"""
if isinstance(messages, (Message, UUID)):
messages = [messages]
@@ -666,14 +686,19 @@ class PromptRepository:
self.db.commit()
def get_stats(self):
def get_stats(self) -> SystemStats:
"""
Get data stats such as number of all messages in the system,
number of deleted and active messages and number of message trees.
"""
deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted)
nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None))
query = deleted.union_all(nthreads)
result = {k: v for k, v in query.all()}
return {
"all": result.get(True, 0) + result.get(False, 0),
"active": result.get(False, 0),
"deleted": result.get(True, 0),
"threads": result.get(None, 0),
}
return SystemStats(
all=result.get(True, 0) + result.get(False, 0),
active=result.get(False, 0),
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
@@ -277,4 +277,4 @@ class SystemStats(BaseModel):
all: int = 0
active: int = 0
deleted: int = 0
threads: int = 0
message_trees: int = 0