mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
adjust names and types to new naming
This commit is contained in:
committed by
Andreas Köpf
parent
475f48b195
commit
f126b21bb3
@@ -5,7 +5,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
@@ -21,10 +21,10 @@ def get_message_by_frontend_id(
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
@@ -38,7 +38,7 @@ def get_conv_by_frontend_id(
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
@@ -52,9 +52,9 @@ def get_tree_by_frontend_id(
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return utils.prepare_tree(tree, message.thread_id)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
@@ -65,7 +65,7 @@ def get_children_by_frontend_id(
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return [
|
||||
protocol.Message(
|
||||
@@ -84,8 +84,8 @@ def get_descendants_by_frontend_id(
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@@ -98,8 +98,8 @@ def get_longest_conv_by_frontend_id(
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@@ -112,6 +112,6 @@ def get_max_children_by_frontend_id(
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post_by_frontend_post_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
@@ -8,7 +8,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
@@ -61,9 +61,9 @@ def get_message(
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
|
||||
message = pr.fetch_message(message_id)
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
|
||||
@@ -89,9 +89,9 @@ def get_tree(
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
tree = pr.fetch_message_tree(message)
|
||||
return utils.prepare_tree(tree, message.thread_id)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
@@ -119,8 +119,8 @@ def get_descendants(
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
descendants = pr.fetch_post_descendants(message)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@@ -132,8 +132,8 @@ def get_longest_conv(
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.thread_id)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@@ -145,8 +145,8 @@ def get_max_children(
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_post(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.thread_id)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
|
||||
@@ -4,15 +4,15 @@ from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import Post
|
||||
from oasst_backend.models.db_payload import PostPayload
|
||||
from oasst_backend.models import Message
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
@@ -21,10 +21,10 @@ def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
|
||||
|
||||
def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
|
||||
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, PostPayload):
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(
|
||||
protocol.Message(
|
||||
|
||||
@@ -54,7 +54,7 @@ class JournalWriter:
|
||||
self.user = user
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
|
||||
def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_message,
|
||||
@@ -63,7 +63,7 @@ class JournalWriter:
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
|
||||
def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_rating,
|
||||
@@ -72,7 +72,7 @@ class JournalWriter:
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
|
||||
def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal:
|
||||
return self.log(
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_ranking,
|
||||
|
||||
@@ -26,7 +26,7 @@ class PromptRepository:
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> User:
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
@@ -119,9 +119,7 @@ class PromptRepository:
|
||||
)
|
||||
return task
|
||||
|
||||
def store_text_reply(
|
||||
self, text: str, frontend_message_id: str, user_frontend_message_id: str, role: str = None
|
||||
) -> Message:
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
@@ -481,7 +479,7 @@ class PromptRepository:
|
||||
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.POST_NOT_FOUND)
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
@@ -502,7 +500,7 @@ class PromptRepository:
|
||||
self.db.commit()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Post] | dict[UUID, Post], last_message: Post) -> list[Post]:
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
Pick messages from a collection so that the result makes a linear conversation
|
||||
starting from a message tree root and up to the given message.
|
||||
@@ -525,66 +523,68 @@ class PromptRepository:
|
||||
|
||||
return list(reversed(conv))
|
||||
|
||||
def fetch_message_conversation(self, message: Post | UUID) -> list[Post]:
|
||||
def fetch_message_conversation(self, message: Message | UUID) -> list[Message]:
|
||||
"""
|
||||
Fetch a conversation from the tree root and up to this message.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_post(message)
|
||||
message = self.fetch_message(message)
|
||||
|
||||
tree_messages = self.fetch_thread(message.thread_id)
|
||||
tree_messages = self.fetch_message_tree(message.message_tree_id)
|
||||
return self.trace_conversation(tree_messages, message)
|
||||
|
||||
def fetch_message_tree(self, message: Post | UUID) -> list[Post]:
|
||||
def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]:
|
||||
"""
|
||||
Fetch message tree this message belongs to.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_post(message)
|
||||
return self.fetch_thread(message.thread_id)
|
||||
message = self.fetch_message(message)
|
||||
return self.fetch_message_tree(message.message_tree_id)
|
||||
|
||||
def fetch_message_children(self, message: Post | UUID) -> list[Post]:
|
||||
def fetch_message_children(self, message: Message | UUID) -> list[Message]:
|
||||
"""
|
||||
Get all direct children of this message
|
||||
"""
|
||||
if isinstance(message, Post):
|
||||
if isinstance(message, Message):
|
||||
message = message.id
|
||||
|
||||
children = self.db.query(Post).filter(Post.parent_id == message).all()
|
||||
children = self.db.query(Message).filter(Message.parent_id == message).all()
|
||||
return children
|
||||
|
||||
@staticmethod
|
||||
def trace_descendants(root: Post, messages: list[Post]) -> list[Post]:
|
||||
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
|
||||
children = defaultdict(list)
|
||||
for msg in messages:
|
||||
children[msg.parent_id].append(msg)
|
||||
|
||||
def _traverse_subtree(m: Post):
|
||||
def _traverse_subtree(m: Message):
|
||||
for child in children[m.id]:
|
||||
yield child
|
||||
yield from _traverse_subtree(child)
|
||||
|
||||
return list(_traverse_subtree(root))
|
||||
|
||||
def fetch_post_descendants(self, message: Post | UUID, max_depth: int = None) -> list[Post]:
|
||||
def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]:
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_post(message)
|
||||
message = self.fetch_message(message)
|
||||
|
||||
desc = self.db.query(Post).filter(Post.thread_id == message.thread_id, Post.depth > message.depth)
|
||||
desc = self.db.query(Message).filter(
|
||||
Message.message_tree_id == message.message_tree_id, Message.depth > message.depth
|
||||
)
|
||||
if max_depth is not None:
|
||||
desc = desc.filter(Post.depth <= max_depth)
|
||||
desc = desc.filter(Message.depth <= max_depth)
|
||||
|
||||
desc = desc.all()
|
||||
|
||||
return self.trace_descendants(message, desc)
|
||||
|
||||
def fetch_longest_conversation(self, message: Post | UUID) -> list[Post]:
|
||||
tree = self.fetch_message_tree(message)
|
||||
def fetch_longest_conversation(self, message: Message | UUID) -> list[Message]:
|
||||
tree = self.fetch_tree_from_message(message)
|
||||
max_message = max(tree, key=lambda m: m.depth)
|
||||
return self.trace_conversation(tree, max_message)
|
||||
|
||||
def fetch_message_with_max_children(self, message: Post | UUID) -> tuple[Post, list[Post]]:
|
||||
tree = self.fetch_message_tree(message)
|
||||
def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Message, list[Message]]:
|
||||
tree = self.fetch_tree_from_message(message)
|
||||
max_message = max(tree, key=lambda m: m.children_count)
|
||||
return max_message, [m for m in tree if m.parent_id == max_message.id]
|
||||
|
||||
@@ -599,7 +599,7 @@ class PromptRepository:
|
||||
end_date: Optional[datetime.datetime] = None,
|
||||
only_roots: bool = False,
|
||||
deleted: Optional[bool] = None,
|
||||
) -> list[Post]:
|
||||
) -> list[Message]:
|
||||
if not self.api_client.trusted and not api_client_id:
|
||||
# Let unprivileged api clients query their own messages without api_client_id being set
|
||||
api_client_id = self.api_client.id
|
||||
@@ -608,30 +608,30 @@ class PromptRepository:
|
||||
# Unprivileged api client asks for foreign messages
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
messages = self.db.query(Post)
|
||||
messages = self.db.query(Message)
|
||||
if user_id:
|
||||
messages = messages.filter(Post.person_id == user_id)
|
||||
messages = messages.filter(Message.user_id == user_id)
|
||||
if username:
|
||||
messages = messages.join(Person)
|
||||
messages = messages.filter(Person.username == username)
|
||||
messages = messages.join(User)
|
||||
messages = messages.filter(User.username == username)
|
||||
if api_client_id:
|
||||
messages = messages.filter(Post.api_client_id == api_client_id)
|
||||
messages = messages.filter(Message.api_client_id == api_client_id)
|
||||
|
||||
if start_date:
|
||||
messages = messages.filter(Post.created_date >= start_date)
|
||||
messages = messages.filter(Message.created_date >= start_date)
|
||||
if end_date:
|
||||
messages = messages.filter(Post.created_date < end_date)
|
||||
messages = messages.filter(Message.created_date < end_date)
|
||||
|
||||
if only_roots:
|
||||
messages = messages.filter(Post.parent_id.is_(None))
|
||||
messages = messages.filter(Message.parent_id.is_(None))
|
||||
|
||||
if deleted is not None:
|
||||
messages = messages.filter(Post.deleted == deleted)
|
||||
messages = messages.filter(Message.deleted == deleted)
|
||||
|
||||
if desc:
|
||||
messages = messages.order_by(Post.created_date.desc())
|
||||
messages = messages.order_by(Message.created_date.desc())
|
||||
else:
|
||||
messages = messages.order_by(Post.created_date.asc())
|
||||
messages = messages.order_by(Message.created_date.asc())
|
||||
|
||||
if limit is not None:
|
||||
messages = messages.limit(limit)
|
||||
@@ -639,34 +639,36 @@ class PromptRepository:
|
||||
# TODO: Pagination could be great at some point
|
||||
return messages.all()
|
||||
|
||||
def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True):
|
||||
if isinstance(messages, (Post, UUID)):
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
if isinstance(messages, (Message, UUID)):
|
||||
messages = [messages]
|
||||
|
||||
ids = []
|
||||
for message in messages:
|
||||
if isinstance(message, UUID):
|
||||
ids.append(message)
|
||||
elif isinstance(message, Post):
|
||||
elif isinstance(message, Message):
|
||||
ids.append(message.id)
|
||||
else:
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
query = update(Post).where(Post.id.in_(ids)).values(deleted=True)
|
||||
query = update(Message).where(Message.id.in_(ids)).values(deleted=True)
|
||||
self.db.execute(query)
|
||||
|
||||
parent_ids = ids
|
||||
if recursive:
|
||||
while parent_ids:
|
||||
query = update(Post).filter(Post.parent_id.in_(parent_ids)).values(deleted=True).returning(Post.id)
|
||||
query = (
|
||||
update(Message).filter(Message.parent_id.in_(parent_ids)).values(deleted=True).returning(Message.id)
|
||||
)
|
||||
|
||||
parent_ids = self.db.execute(query).scalars().all()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def get_stats(self):
|
||||
deleted = self.db.query(Post.deleted, func.count()).group_by(Post.deleted)
|
||||
nthreads = self.db.query(None, func.count(Post.id)).filter(Post.parent_id.is_(None))
|
||||
deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted)
|
||||
nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None))
|
||||
query = deleted.union_all(nthreads)
|
||||
result = {k: v for k, v in query.all()}
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user