mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
fixes
This commit is contained in:
committed by
Andreas Köpf
parent
6d98ba1f75
commit
ef3a35ff9c
@@ -24,7 +24,8 @@ def get_message_by_frontend_id(
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
# Unexpected message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ router = APIRouter()
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
|
||||
@@ -21,7 +21,7 @@ router = APIRouter()
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
@@ -63,7 +63,8 @@ def get_message(
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
# Unexptcted message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
@@ -15,4 +14,4 @@ def get_message_stats(
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return protocol.SystemStats(**pr.get_stats())
|
||||
return pr.get_stats()
|
||||
|
||||
@@ -18,7 +18,7 @@ router = APIRouter()
|
||||
def query_user_messages(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
|
||||
@@ -36,6 +36,7 @@ class OasstErrorCode(IntEnum):
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_MESSAGE_TREE_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
INVALID_MESSAGE = 2008
|
||||
TASK_NOT_FOUND = 2100
|
||||
TASK_EXPIRED = 2101
|
||||
TASK_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
|
||||
@@ -13,6 +13,7 @@ from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
@@ -57,6 +58,7 @@ class PromptRepository:
|
||||
return user
|
||||
|
||||
def validate_frontend_message_id(self, message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
@@ -450,6 +452,13 @@ class PromptRepository:
|
||||
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
"""
|
||||
Fetch a conversation with multiple possible replies to it.
|
||||
|
||||
This function finds a random message with >1 replies,
|
||||
forms a conversation from the corresponding message tree root up to this message
|
||||
and fetches up to max_size possible replies in continuation to this conversation.
|
||||
"""
|
||||
parent = self.db.query(Message.id).filter(Message.children_count > 1)
|
||||
if message_role:
|
||||
parent = parent.filter(Message.role == message_role)
|
||||
@@ -483,6 +492,9 @@ class PromptRepository:
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
@@ -565,6 +577,11 @@ class PromptRepository:
|
||||
return list(_traverse_subtree(root))
|
||||
|
||||
def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]:
|
||||
"""
|
||||
Find all descendant messages to this message.
|
||||
|
||||
This function creates a subtree of messages starting from given root message.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_message(message)
|
||||
|
||||
@@ -640,6 +657,9 @@ class PromptRepository:
|
||||
return messages.all()
|
||||
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
"""
|
||||
Marks deleted messages and all their descendants.
|
||||
"""
|
||||
if isinstance(messages, (Message, UUID)):
|
||||
messages = [messages]
|
||||
|
||||
@@ -666,14 +686,19 @@ class PromptRepository:
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def get_stats(self):
|
||||
def get_stats(self) -> SystemStats:
|
||||
"""
|
||||
Get data stats such as number of all messages in the system,
|
||||
number of deleted and active messages and number of message trees.
|
||||
"""
|
||||
deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted)
|
||||
nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None))
|
||||
query = deleted.union_all(nthreads)
|
||||
result = {k: v for k, v in query.all()}
|
||||
return {
|
||||
"all": result.get(True, 0) + result.get(False, 0),
|
||||
"active": result.get(False, 0),
|
||||
"deleted": result.get(True, 0),
|
||||
"threads": result.get(None, 0),
|
||||
}
|
||||
|
||||
return SystemStats(
|
||||
all=result.get(True, 0) + result.get(False, 0),
|
||||
active=result.get(False, 0),
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
@@ -277,4 +277,4 @@ class SystemStats(BaseModel):
|
||||
all: int = 0
|
||||
active: int = 0
|
||||
deleted: int = 0
|
||||
threads: int = 0
|
||||
message_trees: int = 0
|
||||
|
||||
Reference in New Issue
Block a user