mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
management api
This commit is contained in:
committed by
Andreas Köpf
parent
e485ebcd43
commit
13d01b5a2f
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add deleted field to post
|
||||
|
||||
Revision ID: 6cb49da61b74
|
||||
Revises: 73ce3675c1f5
|
||||
Create Date: 2022-12-30 06:54:47.110204
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6cb49da61b74"
|
||||
down_revision = "73ce3675c1f5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("post", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("post", "deleted")
|
||||
# ### end Alembic commands ###
|
||||
@@ -4,7 +4,7 @@ from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Security
|
||||
from fastapi import Depends, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
@@ -64,3 +64,24 @@ def api_auth(
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
def get_api_client(
|
||||
api_key: APIKey = Depends(get_api_key),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return api_auth(api_key, db)
|
||||
|
||||
|
||||
def get_trusted_api_client(
|
||||
api_key: APIKey = Depends(get_api_key),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
client = api_auth(api_key, db)
|
||||
if not client.trusted:
|
||||
raise OasstError(
|
||||
"Forbidden",
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
return client
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import tasks, text_labels
|
||||
from oasst_backend.api.v1 import management, tasks, text_labels
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
|
||||
api_router.include_router(management.router, prefix="/management", tags=["management"])
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient, Post
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _prepare_conversation(messages: list[Post]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
)
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
|
||||
|
||||
def _prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(
|
||||
protocol.Message(
|
||||
id=message.id,
|
||||
parent_id=message.parent_id,
|
||||
text=message.payload.payload.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
)
|
||||
)
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
|
||||
|
||||
@router.get("/message")
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
max_count: int = Query(10, gt=0, le=25),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
if not api_client.trusted and (api_client_id != api_client.id):
|
||||
# Unprivileged api client asks for foreign messages
|
||||
return []
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
max_count=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/message/{message_id}")
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
post = pr.fetch_post(message_id)
|
||||
if not isinstance(post.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=post.payload.payload.text, is_assistant=(post.role == "assistant"))
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}")
|
||||
def get_message_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
|
||||
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/conversation")
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return _prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/conversation")
|
||||
def get_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return _prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/tree")
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return _prepare_tree(tree, message.thread_id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/tree")
|
||||
def get_tree_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return _prepare_tree(tree, message.thread_id)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/children")
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
return pr.fetch_message_children(message_id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/children")
|
||||
def get_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
return pr.fetch_message_children(message)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/descendants")
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
return _prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/descendants")
|
||||
def get_descendants_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
return _prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
return _prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
return _prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/message/{message_id}/max_children_in_tree")
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
return _prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.get("/frontend_message/{message_id}/max_children_in_tree")
|
||||
def get_max_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
return _prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.delete("/message/{message_id}")
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
|
||||
@router.delete("/user/{username}/message")
|
||||
def mark_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
SERVER_ERROR = 3
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
@@ -34,3 +35,4 @@ class Message(SQLModel, table=True):
|
||||
lang: str = Field(nullable=False, max_length=200, default="en-US")
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -10,6 +13,7 @@ from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
|
||||
|
||||
@@ -492,3 +496,148 @@ class PromptRepository:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Post] | dict[UUID, Post], last_message: Post) -> list[Post]:
|
||||
"""
|
||||
Pick messages from a collection so that the result makes a linear conversation
|
||||
starting from a message tree root and up to the given message.
|
||||
Returns an ordered list of messages starting from the message tree root.
|
||||
"""
|
||||
if isinstance(messages, list):
|
||||
messages = {m.id: m for m in messages}
|
||||
if not isinstance(messages, dict):
|
||||
# This should not normally happen
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
conv = [last_message]
|
||||
while conv[-1].parent_id:
|
||||
if conv[-1].parent_id not in messages:
|
||||
# Can't form a continuous conversation
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
parent_message = messages[conv[-1].parent_id]
|
||||
conv.append(parent_message)
|
||||
|
||||
return list(reversed(conv))
|
||||
|
||||
def fetch_message_conversation(self, message: Post | UUID) -> list[Post]:
|
||||
"""
|
||||
Fetch a conversation from the tree root and up to this message.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_post(message)
|
||||
|
||||
tree_messages = self.fetch_thread(message.thread_id)
|
||||
return self.trace_conversation(tree_messages, message)
|
||||
|
||||
def fetch_message_tree(self, message: Post | UUID) -> list[Post]:
|
||||
"""
|
||||
Fetch message tree this message belongs to.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_post(message)
|
||||
return self.fetch_thread(message.thread_id)
|
||||
|
||||
def fetch_message_children(self, message: Post | UUID) -> list[Post]:
|
||||
"""
|
||||
Get all direct children of this message
|
||||
"""
|
||||
if isinstance(message, Post):
|
||||
message = message.id
|
||||
|
||||
children = self.db.query(Post).filter(Post.parent_id == message).all()
|
||||
return children
|
||||
|
||||
@staticmethod
|
||||
def trace_descendants(root: Post, messages: list[Post]) -> list[Post]:
|
||||
children = defaultdict(list)
|
||||
for msg in messages:
|
||||
children[msg.parent_id].append(msg)
|
||||
|
||||
def _traverse_subtree(m: Post):
|
||||
for child in children[m.id]:
|
||||
yield child
|
||||
yield from _traverse_subtree(child)
|
||||
|
||||
return list(_traverse_subtree(root))
|
||||
|
||||
def fetch_post_descendants(self, message: Post | UUID, max_depth: int = None) -> list[Post]:
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_post(message)
|
||||
|
||||
desc = self.db.query(Post).filter(Post.thread_id == message.thread_id, Post.depth > message.depth)
|
||||
if max_depth is not None:
|
||||
desc = desc.filter(Post.depth <= max_depth)
|
||||
|
||||
desc = desc.all()
|
||||
|
||||
return self.trace_descendants(message, desc)
|
||||
|
||||
def fetch_longest_conversation(self, message: Post | UUID) -> list[Post]:
|
||||
tree = self.fetch_message_tree(message)
|
||||
max_message = max(tree, key=lambda m: m.depth)
|
||||
return self.trace_conversation(tree, max_message)
|
||||
|
||||
def fetch_message_with_max_children(self, message: Post | UUID) -> tuple[Post, list[Post]]:
|
||||
tree = self.fetch_message_tree(message)
|
||||
max_message = max(tree, key=lambda m: m.children_count)
|
||||
return max_message, [m for m in tree if m.parent_id == max_message.id]
|
||||
|
||||
def query_messages(
|
||||
self,
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
desc: bool = True,
|
||||
max_count: int = 10,
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
) -> list[Post]:
|
||||
messages = self.db.query(Post)
|
||||
if username:
|
||||
messages = messages.join(Person)
|
||||
messages = messages.filter(Person.username == username)
|
||||
if api_client_id:
|
||||
messages = messages.filter(Post.api_client_id == api_client_id)
|
||||
|
||||
if start_date:
|
||||
messages = messages.filter(Post.created_date >= start_date)
|
||||
if end_date:
|
||||
messages = messages.filter(Post.created_date < end_date)
|
||||
|
||||
if only_roots:
|
||||
messages = messages.filter(Post.parent_id.is_(None))
|
||||
|
||||
if desc:
|
||||
messages = messages.order_by(Post.created_date.desc())
|
||||
else:
|
||||
messages = messages.order_by(Post.created_date.asc())
|
||||
|
||||
messages = messages.limit(max_count).all()
|
||||
return messages
|
||||
|
||||
def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True):
|
||||
if isinstance(messages, (Post, UUID)):
|
||||
messages = [messages]
|
||||
|
||||
ids = []
|
||||
for message in messages:
|
||||
if isinstance(message, UUID):
|
||||
ids.append(message)
|
||||
elif isinstance(message, Post):
|
||||
ids.append(message.id)
|
||||
else:
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
query = update(Post).where(Post.id.in_(ids)).values(deleted=True)
|
||||
self.db.execute(query)
|
||||
|
||||
parent_ids = ids
|
||||
if recursive:
|
||||
while parent_ids:
|
||||
query = update(Post).filter(Post.parent_id.in_(parent_ids)).values(deleted=True).returning(Post.id)
|
||||
|
||||
parent_ids = self.db.execute(query).scalars().all()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
@@ -38,6 +38,18 @@ class Conversation(BaseModel):
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
|
||||
class Message(ConversationMessage):
|
||||
id: UUID
|
||||
parent_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class MessageTree(BaseModel):
|
||||
"""All messages belonging to the same message tree."""
|
||||
|
||||
id: UUID
|
||||
messages: list[Message] = []
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
"""The frontend asks the backend for a task."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user