mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
split message api endpoints
This commit is contained in:
committed by
Andreas Köpf
parent
13d01b5a2f
commit
8e1d80956a
@@ -1,8 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import management, tasks, text_labels
|
||||
from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, tasks, text_labels, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
|
||||
api_router.include_router(management.router, prefix="/management", tags=["management"])
|
||||
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
|
||||
api_router.include_router(frontend_messages.router, prefix="/frontend_messages", tags=["frontend_messages"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
def get_message_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
|
||||
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
def get_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
def get_tree_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return utils.prepare_tree(tree, message.thread_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
def get_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
def get_descendants_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
def get_max_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{username}/messages")
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{username}/messages")
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -1,291 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient, Post
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _prepare_conversation(messages: list[Post]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
)
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
|
||||
|
||||
def _prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(
|
||||
protocol.Message(
|
||||
id=message.id,
|
||||
parent_id=message.parent_id,
|
||||
text=message.payload.payload.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
)
|
||||
)
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
|
||||
|
||||
@router.get("/message")
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
if not api_client.trusted and (api_client_id != api_client.id):
|
||||
# Unprivileged api client asks for foreign messages
|
||||
return []
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
max_count=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/message/{message_id}")
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
post = pr.fetch_post(message_id)
|
||||
if not isinstance(post.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=post.payload.payload.text, is_assistant=(post.role == "assistant"))
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}")
|
||||
def get_message_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
|
||||
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/conversation")
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return _prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/conversation")
|
||||
def get_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return _prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/tree")
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return _prepare_tree(tree, message.thread_id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/tree")
|
||||
def get_tree_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return _prepare_tree(tree, message.thread_id)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/children")
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
return pr.fetch_message_children(message_id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/children")
|
||||
def get_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
return pr.fetch_message_children(message)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/descendants")
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
return _prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/descendants")
|
||||
def get_descendants_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
return _prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
return _prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
return _prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/max_children_in_tree")
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
return _prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/max_children_in_tree")
|
||||
def get_max_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
return _prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.delete("/message/{message_id}")
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
|
||||
@router.delete("/user/{username}/message")
|
||||
def mark_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -0,0 +1,159 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
allow_deleted: bool = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if allow_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return utils.prepare_tree(tree, message.thread_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.delete("/{message_id}")
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{user_id}/messages")
|
||||
def query_user_messages(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{user_id}/messages")
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -0,0 +1,38 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import Post
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
)
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
|
||||
|
||||
def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(
|
||||
protocol.Message(
|
||||
id=message.id,
|
||||
parent_id=message.parent_id,
|
||||
text=message.payload.payload.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
)
|
||||
)
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
@@ -15,6 +15,7 @@ from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
@@ -477,8 +478,11 @@ class PromptRepository:
|
||||
|
||||
return conversation, replies
|
||||
|
||||
def fetch_message(self, message_id: UUID) -> Optional[Message]:
|
||||
return self.db.query(Message).filter(Message.id == message_id).one()
|
||||
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.POST_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
@@ -586,15 +590,27 @@ class PromptRepository:
|
||||
|
||||
def query_messages(
|
||||
self,
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
desc: bool = True,
|
||||
max_count: int = 10,
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
limit: Optional[int] = 10,
|
||||
start_date: Optional[datetime.datetime] = None,
|
||||
end_date: Optional[datetime.datetime] = None,
|
||||
only_roots: bool = False,
|
||||
deleted: Optional[bool] = None,
|
||||
) -> list[Post]:
|
||||
if not self.api_client.trusted and not api_client_id:
|
||||
# Let unprivileged api clients query their own messages without api_client_id being set
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if not self.api_client.trusted and api_client_id != self.api_client.id:
|
||||
# Unprivileged api client asks for foreign messages
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
messages = self.db.query(Post)
|
||||
if user_id:
|
||||
messages = messages.filter(Post.person_id == user_id)
|
||||
if username:
|
||||
messages = messages.join(Person)
|
||||
messages = messages.filter(Person.username == username)
|
||||
@@ -609,13 +625,19 @@ class PromptRepository:
|
||||
if only_roots:
|
||||
messages = messages.filter(Post.parent_id.is_(None))
|
||||
|
||||
if deleted is not None:
|
||||
messages = messages.filter(Post.deleted == deleted)
|
||||
|
||||
if desc:
|
||||
messages = messages.order_by(Post.created_date.desc())
|
||||
else:
|
||||
messages = messages.order_by(Post.created_date.asc())
|
||||
|
||||
messages = messages.limit(max_count).all()
|
||||
return messages
|
||||
if limit is not None:
|
||||
messages = messages.limit(limit)
|
||||
|
||||
# TODO: Pagination could be great at some point
|
||||
return messages.all()
|
||||
|
||||
def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True):
|
||||
if isinstance(messages, (Post, UUID)):
|
||||
|
||||
Reference in New Issue
Block a user