From 13d01b5a2fcdfcdc2e7777e232593d7e4285c048 Mon Sep 17 00:00:00 2001 From: Igor Miagkov Date: Fri, 30 Dec 2022 06:03:39 +0400 Subject: [PATCH] management api --- ...-6cb49da61b74_add_deleted_field_to_post.py | 28 ++ backend/oasst_backend/api/deps.py | 23 +- backend/oasst_backend/api/v1/api.py | 3 +- backend/oasst_backend/api/v1/management.py | 291 ++++++++++++++++++ backend/oasst_backend/exceptions.py | 1 + backend/oasst_backend/models/message.py | 2 + backend/oasst_backend/prompt_repository.py | 149 +++++++++ oasst-shared/oasst_shared/schemas/protocol.py | 12 + 8 files changed, 507 insertions(+), 2 deletions(-) create mode 100644 backend/alembic/versions/2022_12_30_0654-6cb49da61b74_add_deleted_field_to_post.py create mode 100644 backend/oasst_backend/api/v1/management.py diff --git a/backend/alembic/versions/2022_12_30_0654-6cb49da61b74_add_deleted_field_to_post.py b/backend/alembic/versions/2022_12_30_0654-6cb49da61b74_add_deleted_field_to_post.py new file mode 100644 index 00000000..a8f228bd --- /dev/null +++ b/backend/alembic/versions/2022_12_30_0654-6cb49da61b74_add_deleted_field_to_post.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +"""add deleted field to post + +Revision ID: 6cb49da61b74 +Revises: 73ce3675c1f5 +Create Date: 2022-12-30 06:54:47.110204 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "6cb49da61b74" +down_revision = "73ce3675c1f5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("post", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("post", "deleted") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index 9c4feee2..e0286ba3 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -4,7 +4,7 @@ from secrets import token_hex from typing import Generator from uuid import UUID -from fastapi import Security +from fastapi import Depends, Security from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery from loguru import logger from oasst_backend.config import settings @@ -64,3 +64,24 @@ def api_auth( error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, http_status_code=HTTPStatus.FORBIDDEN, ) + + +def get_api_client( + api_key: APIKey = Depends(get_api_key), + db: Session = Depends(get_db), +): + return api_auth(api_key, db) + + +def get_trusted_api_client( + api_key: APIKey = Depends(get_api_key), + db: Session = Depends(get_db), +): + client = api_auth(api_key, db) + if not client.trusted: + raise OasstError( + "Forbidden", + error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, + http_status_code=HTTPStatus.FORBIDDEN, + ) + return client diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index b54f3dd0..7c9eb493 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- from fastapi import APIRouter -from oasst_backend.api.v1 import tasks, text_labels +from oasst_backend.api.v1 import management, tasks, text_labels api_router = APIRouter() api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"]) +api_router.include_router(management.router, prefix="/management", tags=["management"]) diff --git a/backend/oasst_backend/api/v1/management.py b/backend/oasst_backend/api/v1/management.py new file mode 100644 index 00000000..0e3be690 --- /dev/null +++ 
b/backend/oasst_backend/api/v1/management.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +import datetime +from http import HTTPStatus +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, Response +from oasst_backend.api import deps +from oasst_backend.exceptions import OasstError, OasstErrorCode +from oasst_backend.models import ApiClient, Post +from oasst_backend.models.db_payload import PostPayload +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol +from sqlmodel import Session +from starlette.status import HTTP_200_OK + +router = APIRouter() + + +def _prepare_conversation(messages: list[Post]) -> protocol.Conversation: + conv_messages = [] + for message in messages: + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + conv_messages.append( + protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + ) + + return protocol.Conversation(messages=conv_messages) + + +def _prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree: + tree_messages = [] + for message in tree: + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + tree_messages.append( + protocol.Message( + id=message.id, + parent_id=message.parent_id, + text=message.payload.payload.text, + is_assistant=(message.role == "assistant"), + ) + ) + + return protocol.MessageTree(id=tree_id, messages=tree_messages) + + +@router.get("/message") +def query_messages( + username: str = None, + api_client_id: str = None, + max_count: int = Query(10, gt=0, le=25), + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + only_roots: bool = False, + desc: bool = True, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = 
Depends(deps.get_db), +): + """ + Query up to max_count (1-25) messages, optionally filtered by username, API client, creation date range and root-only flag; an untrusted API client receives an empty list unless api_client_id equals its own client ID. + """ + if not api_client.trusted and (api_client_id != api_client.id): + # Unprivileged api client asks for foreign messages + return [] + + pr = PromptRepository(db, api_client, user=None) + messages = pr.query_messages( + username=username, + api_client_id=api_client_id, + desc=desc, + max_count=max_count, + start_date=start_date, + end_date=end_date, + only_roots=only_roots, + ) + + return [ + protocol.Message( + id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") + ) + for m in messages + ] + + +@router.get("/message/{message_id}") +def get_message( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a message by its internal ID. + """ + pr = PromptRepository(db, api_client, user=None) + post = pr.fetch_post(message_id) + if not isinstance(post.payload.payload, PostPayload): + raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) + + return protocol.ConversationMessage(text=post.payload.payload.text, is_assistant=(post.role == "assistant")) + + +@router.get("/frontend_message/{message_id}") +def get_message_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a message by its frontend ID. 
+ """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True) + + if not isinstance(message.payload.payload, PostPayload): + raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) + + return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + + +@router.get("/message/{message_id}/conversation") +def get_conv( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a conversation from the tree root and up to the message with the given internal ID. + """ + + pr = PromptRepository(db, api_client, user=None) + messages = pr.fetch_message_conversation(message_id) + return _prepare_conversation(messages) + + +@router.get("/frontend_message/{message_id}/conversation") +def get_conv_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a conversation from the tree root and up to the message with the given frontend ID. + """ + + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + messages = pr.fetch_message_conversation(message) + return _prepare_conversation(messages) + + +@router.get("/message/{message_id}/tree") +def get_tree( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all messages belonging to the same message tree. 
+ """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + tree = pr.fetch_message_tree(message) + return _prepare_tree(tree, message.thread_id) + + +@router.get("/frontend_message/{message_id}/tree") +def get_tree_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all messages belonging to the same message tree. + Message is identified by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + tree = pr.fetch_message_tree(message) + return _prepare_tree(tree, message.thread_id) + + +@router.get("/message/{message_id}/children") +def get_children( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all direct children of the message with the given internal ID. + """ + pr = PromptRepository(db, api_client, user=None) + return pr.fetch_message_children(message_id) + + +@router.get("/frontend_message/{message_id}/children") +def get_children_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get all direct children of the message with the given frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + return pr.fetch_message_children(message) + + +@router.get("/message/{message_id}/descendants") +def get_descendants( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a subtree which starts with this message. 
+ """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + descendants = pr.fetch_post_descendants(message) + return _prepare_tree(descendants, message.id) + + +@router.get("/frontend_message/{message_id}/descendants") +def get_descendants_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get a subtree which starts with this message. + The message is identified by its frontend ID. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + descendants = pr.fetch_post_descendants(message) + return _prepare_tree(descendants, message.id) + + +@router.get("/message/{message_id}/longest_conversation_in_tree") +def get_longest_conv( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get the longest conversation from the tree of the message. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + conv = pr.fetch_longest_conversation(message.thread_id) + return _prepare_conversation(conv) + + +@router.get("/frontend_message/{message_id}/longest_conversation_in_tree") +def get_longest_conv_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get the longest conversation from the tree of the message. + The message is identified by its frontend ID. 
+ """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + conv = pr.fetch_longest_conversation(message.thread_id) + return _prepare_conversation(conv) + + +@router.get("/message/{message_id}/max_children_in_tree") +def get_max_children( + message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get message with the most children from the tree of the provided message. + """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post(message_id) + message, children = pr.fetch_message_with_max_children(message.thread_id) + return _prepare_tree([message, *children], message.id) + + +@router.get("/frontend_message/{message_id}/max_children_in_tree") +def get_max_children_by_frontend_id( + message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) +): + """ + Get message with the most children from the tree of the provided message. + The message is identified by its frontend ID. 
+ """ + pr = PromptRepository(db, api_client, user=None) + message = pr.fetch_post_by_frontend_post_id(message_id) + message, children = pr.fetch_message_with_max_children(message.thread_id) + return _prepare_tree([message, *children], message.id) + + +@router.delete("/message/{message_id}") +def mark_message_deleted( + message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) +): + pr = PromptRepository(db, api_client, None) + pr.mark_messages_deleted(message_id) + return Response(status_code=HTTP_200_OK) + + +@router.delete("/user/{username}/message") +def mark_user_messages_deleted( + username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) +): + pr = PromptRepository(db, api_client, None) + messages = pr.query_messages(username=username, api_client_id=api_client.id) + pr.mark_messages_deleted(messages) + return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/exceptions.py b/backend/oasst_backend/exceptions.py index ac7366cd..7f88caed 100644 --- a/backend/oasst_backend/exceptions.py +++ b/backend/oasst_backend/exceptions.py @@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum): GENERIC_ERROR = 0 DATABASE_URI_NOT_SET = 1 API_CLIENT_NOT_AUTHORIZED = 2 + SERVER_ERROR = 3 # 1000-2000: tasks endpoint TASK_INVALID_REQUEST_TYPE = 1000 diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 1425ce98..47512cc7 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -5,6 +5,7 @@ from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg +from sqlalchemy import false from sqlmodel import Field, Index, SQLModel from .payload_column_type import PayloadContainer, payload_column_type @@ -34,3 +35,4 @@ class Message(SQLModel, table=True): lang: str = Field(nullable=False, max_length=200, default="en-US") depth: int = 
Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) + deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 15ed3816..6db35019 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- +import datetime import random +from collections import defaultdict +from http import HTTPStatus from typing import Optional from uuid import UUID, uuid4 @@ -10,6 +13,7 @@ from oasst_backend.journal_writer import JournalWriter from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema +from sqlalchemy import update from sqlmodel import Session, func @@ -492,3 +496,148 @@ class PromptRepository: task.done = True self.db.add(task) self.db.commit() + + @staticmethod + def trace_conversation(messages: list[Post] | dict[UUID, Post], last_message: Post) -> list[Post]: + """ + Pick messages from a collection so that the result makes a linear conversation + starting from a message tree root and up to the given message. + Returns an ordered list of messages starting from the message tree root. 
+ """ + if isinstance(messages, list): + messages = {m.id: m for m in messages} + if not isinstance(messages, dict): + # This should not normally happen + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + + conv = [last_message] + while conv[-1].parent_id: + if conv[-1].parent_id not in messages: + # Can't form a continuous conversation + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + + parent_message = messages[conv[-1].parent_id] + conv.append(parent_message) + + return list(reversed(conv)) + + def fetch_message_conversation(self, message: Post | UUID) -> list[Post]: + """ + Fetch a conversation from the tree root and up to this message. + """ + if isinstance(message, UUID): + message = self.fetch_post(message) + + tree_messages = self.fetch_thread(message.thread_id) + return self.trace_conversation(tree_messages, message) + + def fetch_message_tree(self, message: Post | UUID) -> list[Post]: + """ + Fetch message tree this message belongs to. 
+ """ + if isinstance(message, UUID): + message = self.fetch_post(message) + return self.fetch_thread(message.thread_id) + + def fetch_message_children(self, message: Post | UUID) -> list[Post]: + """ + Get all direct children of this message + """ + if isinstance(message, Post): + message = message.id + + children = self.db.query(Post).filter(Post.parent_id == message).all() + return children + + @staticmethod + def trace_descendants(root: Post, messages: list[Post]) -> list[Post]: + children = defaultdict(list) + for msg in messages: + children[msg.parent_id].append(msg) + + def _traverse_subtree(m: Post): + for child in children[m.id]: + yield child + yield from _traverse_subtree(child) + + return list(_traverse_subtree(root)) + + def fetch_post_descendants(self, message: Post | UUID, max_depth: int = None) -> list[Post]: + if isinstance(message, UUID): + message = self.fetch_post(message) + + desc = self.db.query(Post).filter(Post.thread_id == message.thread_id, Post.depth > message.depth) + if max_depth is not None: + desc = desc.filter(Post.depth <= max_depth) + + desc = desc.all() + + return self.trace_descendants(message, desc) + + def fetch_longest_conversation(self, message: Post | UUID) -> list[Post]: + tree = self.fetch_message_tree(message) + max_message = max(tree, key=lambda m: m.depth) + return self.trace_conversation(tree, max_message) + + def fetch_message_with_max_children(self, message: Post | UUID) -> tuple[Post, list[Post]]: + tree = self.fetch_message_tree(message) + max_message = max(tree, key=lambda m: m.children_count) + return max_message, [m for m in tree if m.parent_id == max_message.id] + + def query_messages( + self, + username: str = None, + api_client_id: str = None, + desc: bool = True, + max_count: int = 10, + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + only_roots: bool = False, + ) -> list[Post]: + messages = self.db.query(Post) + if username: + messages = messages.join(Person) + messages = 
messages.filter(Person.username == username) + if api_client_id: + messages = messages.filter(Post.api_client_id == api_client_id) + + if start_date: + messages = messages.filter(Post.created_date >= start_date) + if end_date: + messages = messages.filter(Post.created_date < end_date) + + if only_roots: + messages = messages.filter(Post.parent_id.is_(None)) + + if desc: + messages = messages.order_by(Post.created_date.desc()) + else: + messages = messages.order_by(Post.created_date.asc()) + + messages = messages.limit(max_count).all() + return messages + + def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True): + if isinstance(messages, (Post, UUID)): + messages = [messages] + + ids = [] + for message in messages: + if isinstance(message, UUID): + ids.append(message) + elif isinstance(message, Post): + ids.append(message.id) + else: + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + + query = update(Post).where(Post.id.in_(ids)).values(deleted=True) + self.db.execute(query) + + parent_ids = ids + if recursive: + while parent_ids: + query = update(Post).filter(Post.parent_id.in_(parent_ids)).values(deleted=True).returning(Post.id) + + parent_ids = self.db.execute(query).scalars().all() + + self.db.commit() diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 8fe8bdea..028154c9 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -38,6 +38,18 @@ class Conversation(BaseModel): messages: list[ConversationMessage] = [] +class Message(ConversationMessage): + id: UUID + parent_id: Optional[UUID] = None + + +class MessageTree(BaseModel): + """All messages belonging to the same message tree.""" + + id: UUID + messages: list[Message] = [] + + class TaskRequest(BaseModel): """The frontend asks the backend for a task."""