management api

This commit is contained in:
Igor Miagkov
2022-12-30 06:03:39 +04:00
committed by Andreas Köpf
parent e485ebcd43
commit 13d01b5a2f
8 changed files with 507 additions and 2 deletions
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""add deleted field to post
Revision ID: 6cb49da61b74
Revises: 73ce3675c1f5
Create Date: 2022-12-30 06:54:47.110204
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "6cb49da61b74"
down_revision = "73ce3675c1f5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("post", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("post", "deleted")
# ### end Alembic commands ###
+22 -1
View File
@@ -4,7 +4,7 @@ from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import Security
from fastapi import Depends, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from oasst_backend.config import settings
@@ -64,3 +64,24 @@ def api_auth(
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
def get_api_client(
api_key: APIKey = Depends(get_api_key),
db: Session = Depends(get_db),
):
return api_auth(api_key, db)
def get_trusted_api_client(
api_key: APIKey = Depends(get_api_key),
db: Session = Depends(get_db),
):
client = api_auth(api_key, db)
if not client.trusted:
raise OasstError(
"Forbidden",
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
return client
+2 -1
View File
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks, text_labels
from oasst_backend.api.v1 import management, tasks, text_labels
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
api_router.include_router(management.router, prefix="/management", tags=["management"])
+291
View File
@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
import datetime
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Response
from oasst_backend.api import deps
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient, Post
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_200_OK
router = APIRouter()
def _prepare_conversation(messages: list[Post]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
)
return protocol.Conversation(messages=conv_messages)
def _prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(
protocol.Message(
id=message.id,
parent_id=message.parent_id,
text=message.payload.payload.text,
is_assistant=(message.role == "assistant"),
)
)
return protocol.MessageTree(id=tree_id, messages=tree_messages)
@router.get("/message")
def query_messages(
username: str = None,
api_client_id: str = None,
max_count: int = Query(10, gt=0, le=25),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query messages.
"""
if not api_client.trusted and (api_client_id != api_client.id):
# Unprivileged api client asks for foreign messages
return []
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
max_count=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
@router.get("/message/{message_id}")
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
post = pr.fetch_post(message_id)
if not isinstance(post.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
return protocol.ConversationMessage(text=post.payload.payload.text, is_assistant=(post.role == "assistant"))
@router.get("/frontend_message/{message_id}")
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@router.get("/message/{message_id}/conversation")
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_conversation(message_id)
return _prepare_conversation(messages)
@router.get("/frontend_message/{message_id}/conversation")
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a conversation from the tree root and up to the message with given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
messages = pr.fetch_message_conversation(message)
return _prepare_conversation(messages)
@router.get("/message/{message_id}/tree")
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
tree = pr.fetch_message_tree(message)
return _prepare_tree(tree, message.thread_id)
@router.get("/frontend_message/{message_id}/tree")
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
tree = pr.fetch_message_tree(message)
return _prepare_tree(tree, message.thread_id)
@router.get("/message/{message_id}/children")
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
return pr.fetch_message_children(message_id)
@router.get("/frontend_message/{message_id}/children")
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
return pr.fetch_message_children(message)
@router.get("/message/{message_id}/descendants")
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
descendants = pr.fetch_post_descendants(message)
return _prepare_tree(descendants, message.id)
@router.get("/frontend_message/{message_id}/descendants")
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
descendants = pr.fetch_post_descendants(message)
return _prepare_tree(descendants, message.id)
@router.get("/message/{message_id}/longest_conversation_in_tree")
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
return _prepare_conversation(conv)
@router.get("/frontend_message/{message_id}/longest_conversation_in_tree")
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
return _prepare_conversation(conv)
@router.get("/message/{message_id}/max_children_in_tree")
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
return _prepare_tree([message, *children], message.id)
@router.get("/frontend_message/{message_id}/max_children_in_tree")
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
return _prepare_tree([message, *children], message.id)
@router.delete("/message/{message_id}")
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr.mark_messages_deleted(message_id)
return Response(status_code=HTTP_200_OK)
@router.delete("/user/{username}/message")
def mark_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)
+1
View File
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
GENERIC_ERROR = 0
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
SERVER_ERROR = 3
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
+2
View File
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
@@ -34,3 +35,4 @@ class Message(SQLModel, table=True):
lang: str = Field(nullable=False, max_length=200, default="en-US")
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
+149
View File
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
import datetime
import random
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4
@@ -10,6 +13,7 @@ from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy import update
from sqlmodel import Session, func
@@ -492,3 +496,148 @@ class PromptRepository:
task.done = True
self.db.add(task)
self.db.commit()
@staticmethod
def trace_conversation(messages: list[Post] | dict[UUID, Post], last_message: Post) -> list[Post]:
"""
Pick messages from a collection so that the result makes a linear conversation
starting from a message tree root and up to the given message.
Returns an ordered list of messages starting from the message tree root.
"""
if isinstance(messages, list):
messages = {m.id: m for m in messages}
if not isinstance(messages, dict):
# This should not normally happen
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv = [last_message]
while conv[-1].parent_id:
if conv[-1].parent_id not in messages:
# Can't form a continuous conversation
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
parent_message = messages[conv[-1].parent_id]
conv.append(parent_message)
return list(reversed(conv))
def fetch_message_conversation(self, message: Post | UUID) -> list[Post]:
"""
Fetch a conversation from the tree root and up to this message.
"""
if isinstance(message, UUID):
message = self.fetch_post(message)
tree_messages = self.fetch_thread(message.thread_id)
return self.trace_conversation(tree_messages, message)
def fetch_message_tree(self, message: Post | UUID) -> list[Post]:
"""
Fetch message tree this message belongs to.
"""
if isinstance(message, UUID):
message = self.fetch_post(message)
return self.fetch_thread(message.thread_id)
def fetch_message_children(self, message: Post | UUID) -> list[Post]:
"""
Get all direct children of this message
"""
if isinstance(message, Post):
message = message.id
children = self.db.query(Post).filter(Post.parent_id == message).all()
return children
@staticmethod
def trace_descendants(root: Post, messages: list[Post]) -> list[Post]:
children = defaultdict(list)
for msg in messages:
children[msg.parent_id].append(msg)
def _traverse_subtree(m: Post):
for child in children[m.id]:
yield child
yield from _traverse_subtree(child)
return list(_traverse_subtree(root))
def fetch_post_descendants(self, message: Post | UUID, max_depth: int = None) -> list[Post]:
if isinstance(message, UUID):
message = self.fetch_post(message)
desc = self.db.query(Post).filter(Post.thread_id == message.thread_id, Post.depth > message.depth)
if max_depth is not None:
desc = desc.filter(Post.depth <= max_depth)
desc = desc.all()
return self.trace_descendants(message, desc)
def fetch_longest_conversation(self, message: Post | UUID) -> list[Post]:
tree = self.fetch_message_tree(message)
max_message = max(tree, key=lambda m: m.depth)
return self.trace_conversation(tree, max_message)
def fetch_message_with_max_children(self, message: Post | UUID) -> tuple[Post, list[Post]]:
tree = self.fetch_message_tree(message)
max_message = max(tree, key=lambda m: m.children_count)
return max_message, [m for m in tree if m.parent_id == max_message.id]
def query_messages(
self,
username: str = None,
api_client_id: str = None,
desc: bool = True,
max_count: int = 10,
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
) -> list[Post]:
messages = self.db.query(Post)
if username:
messages = messages.join(Person)
messages = messages.filter(Person.username == username)
if api_client_id:
messages = messages.filter(Post.api_client_id == api_client_id)
if start_date:
messages = messages.filter(Post.created_date >= start_date)
if end_date:
messages = messages.filter(Post.created_date < end_date)
if only_roots:
messages = messages.filter(Post.parent_id.is_(None))
if desc:
messages = messages.order_by(Post.created_date.desc())
else:
messages = messages.order_by(Post.created_date.asc())
messages = messages.limit(max_count).all()
return messages
def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True):
if isinstance(messages, (Post, UUID)):
messages = [messages]
ids = []
for message in messages:
if isinstance(message, UUID):
ids.append(message)
elif isinstance(message, Post):
ids.append(message.id)
else:
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
query = update(Post).where(Post.id.in_(ids)).values(deleted=True)
self.db.execute(query)
parent_ids = ids
if recursive:
while parent_ids:
query = update(Post).filter(Post.parent_id.in_(parent_ids)).values(deleted=True).returning(Post.id)
parent_ids = self.db.execute(query).scalars().all()
self.db.commit()
@@ -38,6 +38,18 @@ class Conversation(BaseModel):
messages: list[ConversationMessage] = []
class Message(ConversationMessage):
id: UUID
parent_id: Optional[UUID] = None
class MessageTree(BaseModel):
"""All messages belonging to the same message tree."""
id: UUID
messages: list[Message] = []
class TaskRequest(BaseModel):
"""The frontend asks the backend for a task."""