split message api endpoints

This commit is contained in:
Igor Miagkov
2022-12-30 20:02:23 +04:00
committed by Andreas Köpf
parent 13d01b5a2f
commit 8e1d80956a
8 changed files with 470 additions and 302 deletions
+5 -2
View File
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import management, tasks, text_labels
from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, tasks, text_labels, users
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
api_router.include_router(management.router, prefix="/management", tags=["management"])
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
api_router.include_router(frontend_messages.router, prefix="/frontend_messages", tags=["frontend_messages"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@router.get("/{message_id}")
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@router.get("/{message_id}/conversation")
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree")
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
tree = pr.fetch_message_tree(message)
return utils.prepare_tree(tree, message.thread_id)
@router.get("/{message_id}/children")
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
messages = pr.fetch_message_children(message.id)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.get("/{message_id}/descendants")
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
descendants = pr.fetch_post_descendants(message)
return utils.prepare_tree(descendants, message.id)
@router.get("/{message_id}/longest_conversation_in_tree")
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
return utils.prepare_conversation(conv)
@router.get("/{message_id}/max_children_in_tree")
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
return utils.prepare_tree([message, *children], message.id)
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
router = APIRouter()
@router.get("/{username}/messages")
def query_frontend_user_messages(
username: str,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=25),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
include_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if include_deleted else False,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.delete("/{username}/messages")
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)
-291
View File
@@ -1,291 +0,0 @@
# -*- coding: utf-8 -*-
import datetime
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Response
from oasst_backend.api import deps
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient, Post
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_200_OK
router = APIRouter()
def _prepare_conversation(messages: list[Post]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
)
return protocol.Conversation(messages=conv_messages)
def _prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(
protocol.Message(
id=message.id,
parent_id=message.parent_id,
text=message.payload.payload.text,
is_assistant=(message.role == "assistant"),
)
)
return protocol.MessageTree(id=tree_id, messages=tree_messages)
@router.get("/message")
def query_messages(
username: str = None,
api_client_id: str = None,
max_count: int = Query(10, gt=0, le=25),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query messages.
"""
if not api_client.trusted and (api_client_id != api_client.id):
# Unprivileged api client asks for foreign messages
return []
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
max_count=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.get("/message/{message_id}")
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
post = pr.fetch_post(message_id)
if not isinstance(post.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
return protocol.ConversationMessage(text=post.payload.payload.text, is_assistant=(post.role == "assistant"))
@router.get("/frontend_message/{message_id}")
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@router.get("/message/{message_id}/conversation")
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_conversation(message_id)
return _prepare_conversation(messages)
@router.get("/frontend_message/{message_id}/conversation")
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
messages = pr.fetch_message_conversation(message)
return _prepare_conversation(messages)
@router.get("/message/{message_id}/tree")
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
tree = pr.fetch_message_tree(message)
return _prepare_tree(tree, message.thread_id)
@router.get("/frontend_message/{message_id}/tree")
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
tree = pr.fetch_message_tree(message)
return _prepare_tree(tree, message.thread_id)
@router.get("/message/{message_id}/children")
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
return pr.fetch_message_children(message_id)
@router.get("/frontend_message/{message_id}/children")
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
return pr.fetch_message_children(message)
@router.get("/message/{message_id}/descendants")
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
descendants = pr.fetch_post_descendants(message)
return _prepare_tree(descendants, message.id)
@router.get("/frontend_message/{message_id}/descendants")
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
descendants = pr.fetch_post_descendants(message)
return _prepare_tree(descendants, message.id)
@router.get("/message/{message_id}/longest_conversation_in_tree")
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
return _prepare_conversation(conv)
@router.get("/frontend_message/{message_id}/longest_conversation_in_tree")
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
return _prepare_conversation(conv)
@router.get("/message/{message_id}/max_children_in_tree")
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
return _prepare_tree([message, *children], message.id)
@router.get("/frontend_message/{message_id}/max_children_in_tree")
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
return _prepare_tree([message, *children], message.id)
@router.delete("/message/{message_id}")
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr.mark_messages_deleted(message_id)
return Response(status_code=HTTP_200_OK)
@router.delete("/user/{username}/message")
def mark_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)
+159
View File
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Response
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_200_OK
router = APIRouter()
@router.get("/")
def query_messages(
username: str = None,
api_client_id: str = None,
max_count: int = Query(10, gt=0, le=25),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
allow_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if allow_deleted else False,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.get("/{message_id}")
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@router.get("/{message_id}/conversation")
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree")
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
tree = pr.fetch_message_tree(message)
return utils.prepare_tree(tree, message.thread_id)
@router.get("/{message_id}/children")
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_children(message_id)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.get("/{message_id}/descendants")
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
descendants = pr.fetch_post_descendants(message)
return utils.prepare_tree(descendants, message.id)
@router.get("/{message_id}/longest_conversation_in_tree")
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
return utils.prepare_conversation(conv)
@router.get("/{message_id}/max_children_in_tree")
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
return utils.prepare_tree([message, *children], message.id)
@router.delete("/{message_id}")
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr.mark_messages_deleted(message_id)
return Response(status_code=HTTP_200_OK)
+60
View File
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
router = APIRouter()
@router.get("/{user_id}/messages")
def query_user_messages(
user_id: UUID,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=25),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
include_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query user messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
user_id=user_id,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if include_deleted else False,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.delete("/{user_id}/messages")
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)
+38
View File
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from uuid import UUID
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import Post
from oasst_backend.models.db_payload import PostPayload
from oasst_shared.schemas import protocol
def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
)
return protocol.Conversation(messages=conv_messages)
def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(
protocol.Message(
id=message.id,
parent_id=message.parent_id,
text=message.payload.payload.text,
is_assistant=(message.role == "assistant"),
)
)
return protocol.MessageTree(id=tree_id, messages=tree_messages)
+31 -9
View File
@@ -15,6 +15,7 @@ from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN
class PromptRepository:
@@ -477,8 +478,11 @@ class PromptRepository:
return conversation, replies
def fetch_message(self, message_id: UUID) -> Optional[Message]:
return self.db.query(Message).filter(Message.id == message_id).one()
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.POST_NOT_FOUND)
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
self.validate_frontend_message_id(frontend_message_id)
@@ -586,15 +590,27 @@ class PromptRepository:
def query_messages(
self,
username: str = None,
api_client_id: str = None,
user_id: Optional[UUID] = None,
username: Optional[str] = None,
api_client_id: Optional[UUID] = None,
desc: bool = True,
max_count: int = 10,
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
limit: Optional[int] = 10,
start_date: Optional[datetime.datetime] = None,
end_date: Optional[datetime.datetime] = None,
only_roots: bool = False,
deleted: Optional[bool] = None,
) -> list[Post]:
if not self.api_client.trusted and not api_client_id:
# Let unprivileged api clients query their own messages without api_client_id being set
api_client_id = self.api_client.id
if not self.api_client.trusted and api_client_id != self.api_client.id:
# Unprivileged api client asks for foreign messages
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
messages = self.db.query(Post)
if user_id:
messages = messages.filter(Post.person_id == user_id)
if username:
messages = messages.join(Person)
messages = messages.filter(Person.username == username)
@@ -609,13 +625,19 @@ class PromptRepository:
if only_roots:
messages = messages.filter(Post.parent_id.is_(None))
if deleted is not None:
messages = messages.filter(Post.deleted == deleted)
if desc:
messages = messages.order_by(Post.created_date.desc())
else:
messages = messages.order_by(Post.created_date.asc())
messages = messages.limit(max_count).all()
return messages
if limit is not None:
messages = messages.limit(limit)
# TODO: Pagination could be great at some point
return messages.all()
def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True):
if isinstance(messages, (Post, UUID)):