From 8e1d80956acebbe994a5edd5e1871de6197ecbb6 Mon Sep 17 00:00:00 2001 From: Igor Miagkov Date: Fri, 30 Dec 2022 20:02:23 +0400 Subject: [PATCH] split message api endpoints --- backend/oasst_backend/api/v1/api.py | 7 +- .../oasst_backend/api/v1/frontend_messages.py | 117 +++++++ .../oasst_backend/api/v1/frontend_users.py | 60 ++++ backend/oasst_backend/api/v1/management.py | 291 ------------------ backend/oasst_backend/api/v1/messages.py | 159 ++++++++++ backend/oasst_backend/api/v1/users.py | 60 ++++ backend/oasst_backend/api/v1/utils.py | 38 +++ backend/oasst_backend/prompt_repository.py | 40 ++- 8 files changed, 470 insertions(+), 302 deletions(-) create mode 100644 backend/oasst_backend/api/v1/frontend_messages.py create mode 100644 backend/oasst_backend/api/v1/frontend_users.py delete mode 100644 backend/oasst_backend/api/v1/management.py create mode 100644 backend/oasst_backend/api/v1/messages.py create mode 100644 backend/oasst_backend/api/v1/users.py create mode 100644 backend/oasst_backend/api/v1/utils.py diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index 7c9eb493..b39d8d59 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- from fastapi import APIRouter -from oasst_backend.api.v1 import management, tasks, text_labels +from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, tasks, text_labels, users api_router = APIRouter() api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"]) -api_router.include_router(management.router, prefix="/management", tags=["management"]) +api_router.include_router(messages.router, prefix="/messages", tags=["messages"]) +api_router.include_router(frontend_messages.router, prefix="/frontend_messages", tags=["frontend_messages"]) +api_router.include_router(users.router, prefix="/users", tags=["users"]) +api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"]) diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py new file mode 100644 index 00000000..acdd388a --- /dev/null +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- + +from fastapi import APIRouter, Depends +from oasst_backend.api import deps +from oasst_backend.api.v1 import utils +from oasst_backend.exceptions import OasstError, OasstErrorCode +from oasst_backend.models import ApiClient +from oasst_backend.models.db_payload import PostPayload +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol +from sqlmodel import Session + +router = APIRouter() + + +@router.get("/{message_id}") +def get_message_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a message by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True) + + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) + + return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + + +@router.get("/{message_id}/conversation") +def get_conv_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a conversation from the tree root and up to the message with given frontend ID. + """ + + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + messages = pr.fetch_message_conversation(message) + return utils.prepare_conversation(messages) + + +@router.get("/{message_id}/tree") +def get_tree_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all messages belonging to the same message tree. + Message is identified by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + tree = pr.fetch_message_tree(message) + return utils.prepare_tree(tree, message.thread_id) + + +@router.get("/{message_id}/children") +def get_children_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all messages belonging to the same message tree. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + messages = pr.fetch_message_children(message.id) + return [ + protocol.Message( + id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") + ) + for m in messages + ] + + +@router.get("/{message_id}/descendants") +def get_descendants_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a subtree which starts with this message. + The message is identified by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + descendants = pr.fetch_post_descendants(message) + return utils.prepare_tree(descendants, message.id) + + +@router.get("/{message_id}/longest_conversation_in_tree") +def get_longest_conv_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get the longest conversation from the tree of the message. + The message is identified by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + conv = pr.fetch_longest_conversation(message.thread_id) + return utils.prepare_conversation(conv) + + +@router.get("/{message_id}/max_children_in_tree") +def get_max_children_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get message with the most children from the tree of the provided message. + The message is identified by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + message, children = pr.fetch_message_with_max_children(message.thread_id) + return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py new file mode 100644 index 00000000..f5df2b24 --- /dev/null +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import datetime +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from oasst_backend.api import deps +from oasst_backend.models import ApiClient +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol +from sqlmodel import Session +from starlette.responses import Response +from starlette.status import HTTP_200_OK + +router = APIRouter() + + +@router.get("/{username}/messages") +def query_frontend_user_messages( + username: str, + api_client_id: UUID = None, + max_count: int = Query(10, gt=0, le=25), + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + only_roots: bool = False, + desc: bool = True, + include_deleted: bool = False, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + """ + Query frontend user messages. + """ + pr = PromptRepository(db, api_client, user=None) + messages = pr.query_messages( + username=username, + api_client_id=api_client_id, + desc=desc, + limit=max_count, + start_date=start_date, + end_date=end_date, + only_roots=only_roots, + deleted=None if include_deleted else False, + ) + + return [ + protocol.Message( + id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") + ) + for m in messages + ] + + +@router.delete("/{username}/messages") +def mark_frontend_user_messages_deleted( + username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) +): + pr = PromptRepository(db, api_client, None) + messages = pr.query_messages(username=username, api_client_id=api_client.id) + pr.mark_messages_deleted(messages) + return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/api/v1/management.py b/backend/oasst_backend/api/v1/management.py deleted file mode 100644 index 0e3be690..00000000 --- a/backend/oasst_backend/api/v1/management.py +++ /dev/null @@ -1,291 +0,0 @@ -# -*- coding: utf-8 -*- -import datetime -from http import HTTPStatus -from uuid import UUID - -from fastapi import APIRouter, Depends, Query, Response -from oasst_backend.api import deps -from oasst_backend.exceptions import OasstError, OasstErrorCode -from oasst_backend.models import ApiClient, Post -from oasst_backend.models.db_payload import PostPayload -from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.schemas import protocol -from sqlmodel import Session -from starlette.status import HTTP_200_OK - -router = APIRouter() - - -def _prepare_conversation(messages: list[Post]) -> protocol.Conversation: - conv_messages = [] - for message in messages: - if not isinstance(message.payload.payload, PostPayload): - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) - conv_messages.append( - protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) - ) - - return protocol.Conversation(messages=conv_messages) - - -def _prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree: - tree_messages = [] - for message in tree: - if not isinstance(message.payload.payload, PostPayload): - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) - tree_messages.append( - protocol.Message( - id=message.id, - parent_id=message.parent_id, - text=message.payload.payload.text, - is_assistant=(message.role == "assistant"), - ) - ) - - return protocol.MessageTree(id=tree_id, messages=tree_messages) - - -@router.get("/message") -def query_messages( - username: str = None, - api_client_id: str = None, - max_count: int = Query(10, gt=0, le=25), - start_date: datetime.datetime = None, - end_date: datetime.datetime = None, - only_roots: bool = False, - desc: bool = True, - api_client: ApiClient = Depends(deps.get_api_client), - db: Session = Depends(deps.get_db), -): - """ - Query messages. - """ - if not api_client.trusted and (api_client_id != api_client.id): - # Unprivileged api client asks for foreign messages - return [] - - pr = PromptRepository(db, api_client, user=None) - messages = pr.query_messages( - username=username, - api_client_id=api_client_id, - desc=desc, - max_count=max_count, - start_date=start_date, - end_date=end_date, - only_roots=only_roots, - ) - - return [ - protocol.Message( - id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") - ) - for m in messages - ] - - -@router.get("/message/{message_id}") -def get_message( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get a message by its internal ID. - """ - pr = PromptRepository(db, api_client, user=None) - post = pr.fetch_post(message_id) - if not isinstance(post.payload.payload, PostPayload): - raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) - - return protocol.ConversationMessage(text=post.payload.payload.text, is_assistant=(post.role == "assistant")) - - -@router.get("/frontend_message/{message_id}") -def get_message_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get a message by its frontend ID. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True) - - if not isinstance(message.payload.payload, PostPayload): - raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) - - return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) - - -@router.get("/message/{message_id}/conversation") -def get_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get a conversation from the tree root and up to the message with given internal ID. - """ - - pr = PromptRepository(db, api_client, user=None) - messages = pr.fetch_message_conversation(message_id) - return _prepare_conversation(messages) - - -@router.get("/frontend_message/{message_id}/conversation") -def get_conv_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get a conversation from the tree root and up to the message with given frontend ID. - """ - - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - messages = pr.fetch_message_conversation(message) - return _prepare_conversation(messages) - - -@router.get("/message/{message_id}/tree") -def get_tree( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get all messages belonging to the same message tree. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - tree = pr.fetch_message_tree(message) - return _prepare_tree(tree, message.thread_id) - - -@router.get("/frontend_message/{message_id}/tree") -def get_tree_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get all messages belonging to the same message tree. - Message is identified by its frontend ID. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - tree = pr.fetch_message_tree(message) - return _prepare_tree(tree, message.thread_id) - - -@router.get("/message/{message_id}/children") -def get_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get all messages belonging to the same message tree. - """ - pr = PromptRepository(db, api_client, user=None) - return pr.fetch_message_children(message_id) - - -@router.get("/frontend_message/{message_id}/children") -def get_children_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get all messages belonging to the same message tree. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - return pr.fetch_message_children(message) - - -@router.get("/message/{message_id}/descendants") -def get_descendants( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get a subtree which starts with this message. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - descendants = pr.fetch_post_descendants(message) - return _prepare_tree(descendants, message.id) - - -@router.get("/frontend_message/{message_id}/descendants") -def get_descendants_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get a subtree which starts with this message. - The message is identified by its frontend ID. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - descendants = pr.fetch_post_descendants(message) - return _prepare_tree(descendants, message.id) - - -@router.get("/message/{message_id}/longest_conversation_in_tree") -def get_longest_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get the longest conversation from the tree of the message. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - conv = pr.fetch_longest_conversation(message.thread_id) - return _prepare_conversation(conv) - - -@router.get("/frontend_message/{message_id}/longest_conversation_in_tree") -def get_longest_conv_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get the longest conversation from the tree of the message. - The message is identified by its frontend ID. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - conv = pr.fetch_longest_conversation(message.thread_id) - return _prepare_conversation(conv) - - -@router.get("/message/{message_id}/max_children_in_tree") -def get_max_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get message with the most children from the tree of the provided message. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - message, children = pr.fetch_message_with_max_children(message.thread_id) - return _prepare_tree([message, *children], message.id) - - -@router.get("/frontend_message/{message_id}/max_children_in_tree") -def get_max_children_by_frontend_id( - message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) -): - """ - Get message with the most children from the tree of the provided message. - The message is identified by its frontend ID. - """ - pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - message, children = pr.fetch_message_with_max_children(message.thread_id) - return _prepare_tree([message, *children], message.id) - - -@router.delete("/message/{message_id}") -def mark_message_deleted( - message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) -): - pr = PromptRepository(db, api_client, None) - pr.mark_messages_deleted(message_id) - return Response(status_code=HTTP_200_OK) - - -@router.delete("/user/{username}/message") -def mark_user_messages_deleted( - username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) -): - pr = PromptRepository(db, api_client, None) - messages = pr.query_messages(username=username, api_client_id=api_client.id) - pr.mark_messages_deleted(messages) - return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py new file mode 100644 index 00000000..bd8e7f12 --- /dev/null +++ b/backend/oasst_backend/api/v1/messages.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- + +import datetime +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, Response +from oasst_backend.api import deps +from oasst_backend.api.v1 import utils +from oasst_backend.exceptions import OasstError, OasstErrorCode +from oasst_backend.models import ApiClient +from oasst_backend.models.db_payload import PostPayload +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol +from sqlmodel import Session +from starlette.status import HTTP_200_OK + +router = APIRouter() + + +@router.get("/") +def query_messages( + username: str = None, + api_client_id: str = None, + max_count: int = Query(10, gt=0, le=25), + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + only_roots: bool = False, + desc: bool = True, + allow_deleted: bool = False, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + """ + Query messages. + """ + pr = PromptRepository(db, api_client, user=None) + messages = pr.query_messages( + username=username, + api_client_id=api_client_id, + desc=desc, + limit=max_count, + start_date=start_date, + end_date=end_date, + only_roots=only_roots, + deleted=None if allow_deleted else False, + ) + + return [ + protocol.Message( + id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") + ) + for m in messages + ] + + +@router.get("/{message_id}") +def get_message( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a message by its internal ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) + + return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + + +@router.get("/{message_id}/conversation") +def get_conv( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a conversation from the tree root and up to the message with given internal ID. + """ + + pr = PromptRepository(db, api_client, user=None) + messages = pr.fetch_message_conversation(message_id) + return utils.prepare_conversation(messages) + + +@router.get("/{message_id}/tree") +def get_tree( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all messages belonging to the same message tree. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + tree = pr.fetch_message_tree(message) + return utils.prepare_tree(tree, message.thread_id) + + +@router.get("/{message_id}/children") +def get_children( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all messages belonging to the same message tree. + """ + pr = PromptRepository(db, api_client, user=None) + messages = pr.fetch_message_children(message_id) + return [ + protocol.Message( + id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") + ) + for m in messages + ] + + +@router.get("/{message_id}/descendants") +def get_descendants( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a subtree which starts with this message. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + descendants = pr.fetch_post_descendants(message) + return utils.prepare_tree(descendants, message.id) + + +@router.get("/{message_id}/longest_conversation_in_tree") +def get_longest_conv( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get the longest conversation from the tree of the message. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + conv = pr.fetch_longest_conversation(message.thread_id) + return utils.prepare_conversation(conv) + + +@router.get("/{message_id}/max_children_in_tree") +def get_max_children( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get message with the most children from the tree of the provided message. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + message, children = pr.fetch_message_with_max_children(message.thread_id) + return utils.prepare_tree([message, *children], message.id) + + +@router.delete("/{message_id}") +def mark_message_deleted( + message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) +): + pr = PromptRepository(db, api_client, None) + pr.mark_messages_deleted(message_id) + return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py new file mode 100644 index 00000000..000b0970 --- /dev/null +++ b/backend/oasst_backend/api/v1/users.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import datetime +from uuid import UUID + +from fastapi import APIRouter, Depends, Query +from oasst_backend.api import deps +from oasst_backend.models import ApiClient +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol +from sqlmodel import Session +from starlette.responses import Response +from starlette.status import HTTP_200_OK + +router = APIRouter() + + +@router.get("/{user_id}/messages") +def query_user_messages( + user_id: UUID, + api_client_id: UUID = None, + max_count: int = Query(10, gt=0, le=25), + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + only_roots: bool = False, + desc: bool = True, + include_deleted: bool = False, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + """ + Query user messages. + """ + pr = PromptRepository(db, api_client, user=None) + messages = pr.query_messages( + user_id=user_id, + api_client_id=api_client_id, + desc=desc, + limit=max_count, + start_date=start_date, + end_date=end_date, + only_roots=only_roots, + deleted=None if include_deleted else False, + ) + + return [ + protocol.Message( + id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") + ) + for m in messages + ] + + +@router.delete("/{user_id}/messages") +def mark_user_messages_deleted( + user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) +): + pr = PromptRepository(db, api_client, None) + messages = pr.query_messages(user_id=user_id) + pr.mark_messages_deleted(messages) + return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py new file mode 100644 index 00000000..48f11038 --- /dev/null +++ b/backend/oasst_backend/api/v1/utils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +from http import HTTPStatus +from uuid import UUID + +from oasst_backend.exceptions import OasstError, OasstErrorCode +from oasst_backend.models import Post +from oasst_backend.models.db_payload import PostPayload +from oasst_shared.schemas import protocol + + +def prepare_conversation(messages: list[Post]) -> protocol.Conversation: + conv_messages = [] + for message in messages: + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + conv_messages.append( + protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + ) + + return protocol.Conversation(messages=conv_messages) + + +def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree: + tree_messages = [] + for message in tree: + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + tree_messages.append( + protocol.Message( + id=message.id, + parent_id=message.parent_id, + text=message.payload.payload.text, + is_assistant=(message.role == "assistant"), + ) + ) + + return protocol.MessageTree(id=tree_id, messages=tree_messages) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 6db35019..05902b5a 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -15,6 +15,7 @@ from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema from sqlalchemy import update from sqlmodel import Session, func +from starlette.status import HTTP_403_FORBIDDEN class PromptRepository: @@ -477,8 +478,11 @@ class PromptRepository: return conversation, replies - def fetch_message(self, message_id: UUID) -> Optional[Message]: - return self.db.query(Message).filter(Message.id == message_id).one() + def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]: + message = self.db.query(Message).filter(Message.id == message_id).one_or_none() + if fail_if_missing and not message: + raise OasstError("Message not found", OasstErrorCode.POST_NOT_FOUND) + return message def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): self.validate_frontend_message_id(frontend_message_id) @@ -586,15 +590,27 @@ class PromptRepository: def query_messages( self, - username: str = None, - api_client_id: str = None, + user_id: Optional[UUID] = None, + username: Optional[str] = None, + api_client_id: Optional[UUID] = None, desc: bool = True, - max_count: int = 10, - start_date: datetime.datetime = None, - end_date: datetime.datetime = None, + limit: Optional[int] = 10, + start_date: Optional[datetime.datetime] = None, + end_date: Optional[datetime.datetime] = None, only_roots: bool = False, + deleted: Optional[bool] = None, ) -> list[Post]: + if not self.api_client.trusted and not api_client_id: + # Let unprivileged api clients query their own messages without api_client_id being set + api_client_id = self.api_client.id + + if not self.api_client.trusted and api_client_id != self.api_client.id: + # Unprivileged api client asks for foreign messages + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + messages = self.db.query(Post) + if user_id: + messages = messages.filter(Post.person_id == user_id) if username: messages = messages.join(Person) messages = messages.filter(Person.username == username) @@ -609,13 +625,19 @@ class PromptRepository: if only_roots: messages = messages.filter(Post.parent_id.is_(None)) + if deleted is not None: + messages = messages.filter(Post.deleted == deleted) + if desc: messages = messages.order_by(Post.created_date.desc()) else: messages = messages.order_by(Post.created_date.asc()) - messages = messages.limit(max_count).all() - return messages + if limit is not None: + messages = messages.limit(limit) + + # TODO: Pagination could be great at some point + return messages.all() def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True): if isinstance(messages, (Post, UUID)):