diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index acdd388a..c53f6bae 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -5,7 +5,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.models import ApiClient -from oasst_backend.models.db_payload import PostPayload +from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol from sqlmodel import Session @@ -21,10 +21,10 @@ def get_message_by_frontend_id( Get a message by its frontend ID. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True) + message = pr.fetch_message_by_frontend_message_id(message_id) - if not isinstance(message.payload.payload, PostPayload): - raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) + if not isinstance(message.payload.payload, MessagePayload): + raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) @@ -38,7 +38,7 @@ def get_conv_by_frontend_id( """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) + message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_conversation(message) return utils.prepare_conversation(messages) @@ -52,9 +52,9 @@ def get_tree_by_frontend_id( Message is identified by its frontend ID. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - tree = pr.fetch_message_tree(message) - return utils.prepare_tree(tree, message.thread_id) + message = pr.fetch_message_by_frontend_message_id(message_id) + tree = pr.fetch_message_tree(message.message_tree_id) + return utils.prepare_tree(tree, message.message_tree_id) @router.get("/{message_id}/children") @@ -65,7 +65,7 @@ def get_children_by_frontend_id( Get all messages belonging to the same message tree. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) + message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_children(message.id) return [ protocol.Message( @@ -84,8 +84,8 @@ def get_descendants_by_frontend_id( The message is identified by its frontend ID. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - descendants = pr.fetch_post_descendants(message) + message = pr.fetch_message_by_frontend_message_id(message_id) + descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -98,8 +98,8 @@ def get_longest_conv_by_frontend_id( The message is identified by its frontend ID. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - conv = pr.fetch_longest_conversation(message.thread_id) + message = pr.fetch_message_by_frontend_message_id(message_id) + conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -112,6 +112,6 @@ def get_max_children_by_frontend_id( The message is identified by its frontend ID. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post_by_frontend_post_id(message_id) - message, children = pr.fetch_message_with_max_children(message.thread_id) + message = pr.fetch_message_by_frontend_message_id(message_id) + message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index bd8e7f12..a812ef61 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -8,7 +8,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.models import ApiClient -from oasst_backend.models.db_payload import PostPayload +from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol from sqlmodel import Session @@ -61,9 +61,9 @@ def get_message( Get a message by its internal ID. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - if not isinstance(message.payload.payload, PostPayload): - raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID) + message = pr.fetch_message(message_id) + if not isinstance(message.payload.payload, MessagePayload): + raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) @@ -89,9 +89,9 @@ def get_tree( Get all messages belonging to the same message tree. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - tree = pr.fetch_message_tree(message) - return utils.prepare_tree(tree, message.thread_id) + message = pr.fetch_message(message_id) + tree = pr.fetch_message_tree(message.message_tree_id) + return utils.prepare_tree(tree, message.message_tree_id) @router.get("/{message_id}/children") @@ -119,8 +119,8 @@ def get_descendants( Get a subtree which starts with this message. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - descendants = pr.fetch_post_descendants(message) + message = pr.fetch_message(message_id) + descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -132,8 +132,8 @@ def get_longest_conv( Get the longest conversation from the tree of the message. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - conv = pr.fetch_longest_conversation(message.thread_id) + message = pr.fetch_message(message_id) + conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -145,8 +145,8 @@ def get_max_children( Get message with the most children from the tree of the provided message. """ pr = PromptRepository(db, api_client, user=None) - message = pr.fetch_post(message_id) - message, children = pr.fetch_message_with_max_children(message.thread_id) + message = pr.fetch_message(message_id) + message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 48f11038..2dac7947 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -4,15 +4,15 @@ from http import HTTPStatus from uuid import UUID from oasst_backend.exceptions import OasstError, OasstErrorCode -from oasst_backend.models import Post -from oasst_backend.models.db_payload import PostPayload +from oasst_backend.models import Message +from oasst_backend.models.db_payload import MessagePayload from oasst_shared.schemas import protocol -def prepare_conversation(messages: list[Post]) -> protocol.Conversation: +def prepare_conversation(messages: list[Message]) -> protocol.Conversation: conv_messages = [] for message in messages: - if not isinstance(message.payload.payload, PostPayload): + if not isinstance(message.payload.payload, MessagePayload): raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) conv_messages.append( protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) @@ -21,10 +21,10 @@ def prepare_conversation(messages: list[Post]) -> protocol.Conversation: return protocol.Conversation(messages=conv_messages) -def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree: +def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree: tree_messages = [] for message in tree: - if not isinstance(message.payload.payload, PostPayload): + if not isinstance(message.payload.payload, MessagePayload): raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) tree_messages.append( protocol.Message( diff --git a/backend/oasst_backend/journal_writer.py b/backend/oasst_backend/journal_writer.py index 64327366..415d5a47 100644 --- a/backend/oasst_backend/journal_writer.py +++ b/backend/oasst_backend/journal_writer.py @@ -54,7 +54,7 @@ class JournalWriter: self.user = user self.user_id = self.user.id if self.user else None - def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal: + def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal: return self.log( task_type=task.payload_type, event_type=JournalEventType.text_reply_to_message, @@ -63,7 +63,7 @@ class JournalWriter: message_id=message_id, ) - def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal: + def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal: return self.log( task_type=task.payload_type, event_type=JournalEventType.message_rating, @@ -72,7 +72,7 @@ class JournalWriter: message_id=message_id, ) - def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal: + def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal: return self.log( task_type=task.payload_type, event_type=JournalEventType.message_ranking, diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8df17347..155a6d2a 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -26,7 +26,7 @@ class PromptRepository: self.user_id = self.user.id if self.user else None self.journal = JournalWriter(db, api_client, self.user) - def lookup_user(self, client_user: protocol_schema.User) -> User: + def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]: if not client_user: return None user: User = ( @@ -119,9 +119,7 @@ class PromptRepository: ) return task - def store_text_reply( - self, text: str, frontend_message_id: str, user_frontend_message_id: str, role: str = None - ) -> Message: + def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message: self.validate_frontend_message_id(frontend_message_id) self.validate_frontend_message_id(user_frontend_message_id) @@ -481,7 +479,7 @@ class PromptRepository: def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]: message = self.db.query(Message).filter(Message.id == message_id).one_or_none() if fail_if_missing and not message: - raise OasstError("Message not found", OasstErrorCode.POST_NOT_FOUND) + raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND) return message def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): @@ -502,7 +500,7 @@ class PromptRepository: self.db.commit() @staticmethod - def trace_conversation(messages: list[Post] | dict[UUID, Post], last_message: Post) -> list[Post]: + def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]: """ Pick messages from a collection so that the result makes a linear conversation starting from a message tree root and up to the given message. @@ -525,66 +523,68 @@ class PromptRepository: return list(reversed(conv)) - def fetch_message_conversation(self, message: Post | UUID) -> list[Post]: + def fetch_message_conversation(self, message: Message | UUID) -> list[Message]: """ Fetch a conversation from the tree root and up to this message. """ if isinstance(message, UUID): - message = self.fetch_post(message) + message = self.fetch_message(message) - tree_messages = self.fetch_thread(message.thread_id) + tree_messages = self.fetch_message_tree(message.message_tree_id) return self.trace_conversation(tree_messages, message) - def fetch_message_tree(self, message: Post | UUID) -> list[Post]: + def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]: """ Fetch message tree this message belongs to. """ if isinstance(message, UUID): - message = self.fetch_post(message) - return self.fetch_thread(message.thread_id) + message = self.fetch_message(message) + return self.fetch_message_tree(message.message_tree_id) - def fetch_message_children(self, message: Post | UUID) -> list[Post]: + def fetch_message_children(self, message: Message | UUID) -> list[Message]: """ Get all direct children of this message """ - if isinstance(message, Post): + if isinstance(message, Message): message = message.id - children = self.db.query(Post).filter(Post.parent_id == message).all() + children = self.db.query(Message).filter(Message.parent_id == message).all() return children @staticmethod - def trace_descendants(root: Post, messages: list[Post]) -> list[Post]: + def trace_descendants(root: Message, messages: list[Message]) -> list[Message]: children = defaultdict(list) for msg in messages: children[msg.parent_id].append(msg) - def _traverse_subtree(m: Post): + def _traverse_subtree(m: Message): for child in children[m.id]: yield child yield from _traverse_subtree(child) return list(_traverse_subtree(root)) - def fetch_post_descendants(self, message: Post | UUID, max_depth: int = None) -> list[Post]: + def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]: if isinstance(message, UUID): - message = self.fetch_post(message) + message = self.fetch_message(message) - desc = self.db.query(Post).filter(Post.thread_id == message.thread_id, Post.depth > message.depth) + desc = self.db.query(Message).filter( + Message.message_tree_id == message.message_tree_id, Message.depth > message.depth + ) if max_depth is not None: - desc = desc.filter(Post.depth <= max_depth) + desc = desc.filter(Message.depth <= max_depth) desc = desc.all() return self.trace_descendants(message, desc) - def fetch_longest_conversation(self, message: Post | UUID) -> list[Post]: - tree = self.fetch_message_tree(message) + def fetch_longest_conversation(self, message: Message | UUID) -> list[Message]: + tree = self.fetch_tree_from_message(message) max_message = max(tree, key=lambda m: m.depth) return self.trace_conversation(tree, max_message) - def fetch_message_with_max_children(self, message: Post | UUID) -> tuple[Post, list[Post]]: - tree = self.fetch_message_tree(message) + def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Message, list[Message]]: + tree = self.fetch_tree_from_message(message) max_message = max(tree, key=lambda m: m.children_count) return max_message, [m for m in tree if m.parent_id == max_message.id] @@ -599,7 +599,7 @@ class PromptRepository: end_date: Optional[datetime.datetime] = None, only_roots: bool = False, deleted: Optional[bool] = None, - ) -> list[Post]: + ) -> list[Message]: if not self.api_client.trusted and not api_client_id: # Let unprivileged api clients query their own messages without api_client_id being set api_client_id = self.api_client.id @@ -608,30 +608,30 @@ class PromptRepository: # Unprivileged api client asks for foreign messages raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - messages = self.db.query(Post) + messages = self.db.query(Message) if user_id: - messages = messages.filter(Post.person_id == user_id) + messages = messages.filter(Message.user_id == user_id) if username: - messages = messages.join(Person) - messages = messages.filter(Person.username == username) + messages = messages.join(User) + messages = messages.filter(User.username == username) if api_client_id: - messages = messages.filter(Post.api_client_id == api_client_id) + messages = messages.filter(Message.api_client_id == api_client_id) if start_date: - messages = messages.filter(Post.created_date >= start_date) + messages = messages.filter(Message.created_date >= start_date) if end_date: - messages = messages.filter(Post.created_date < end_date) + messages = messages.filter(Message.created_date < end_date) if only_roots: - messages = messages.filter(Post.parent_id.is_(None)) + messages = messages.filter(Message.parent_id.is_(None)) if deleted is not None: - messages = messages.filter(Post.deleted == deleted) + messages = messages.filter(Message.deleted == deleted) if desc: - messages = messages.order_by(Post.created_date.desc()) + messages = messages.order_by(Message.created_date.desc()) else: - messages = messages.order_by(Post.created_date.asc()) + messages = messages.order_by(Message.created_date.asc()) if limit is not None: messages = messages.limit(limit) @@ -639,34 +639,36 @@ class PromptRepository: # TODO: Pagination could be great at some point return messages.all() - def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True): - if isinstance(messages, (Post, UUID)): + def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): + if isinstance(messages, (Message, UUID)): messages = [messages] ids = [] for message in messages: if isinstance(message, UUID): ids.append(message) - elif isinstance(message, Post): + elif isinstance(message, Message): ids.append(message.id) else: raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) - query = update(Post).where(Post.id.in_(ids)).values(deleted=True) + query = update(Message).where(Message.id.in_(ids)).values(deleted=True) self.db.execute(query) parent_ids = ids if recursive: while parent_ids: - query = update(Post).filter(Post.parent_id.in_(parent_ids)).values(deleted=True).returning(Post.id) + query = ( + update(Message).filter(Message.parent_id.in_(parent_ids)).values(deleted=True).returning(Message.id) + ) parent_ids = self.db.execute(query).scalars().all() self.db.commit() def get_stats(self): - deleted = self.db.query(Post.deleted, func.count()).group_by(Post.deleted) - nthreads = self.db.query(None, func.count(Post.id)).filter(Post.parent_id.is_(None)) + deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted) + nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None)) query = deleted.union_all(nthreads) result = {k: v for k, v in query.all()} return {