adjust names and types to new naming

This commit is contained in:
Igor Miagkov
2022-12-31 04:35:33 +04:00
committed by Andreas Köpf
parent 475f48b195
commit f126b21bb3
5 changed files with 83 additions and 81 deletions
@@ -5,7 +5,7 @@ from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
@@ -21,10 +21,10 @@ def get_message_by_frontend_id(
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id, fail_if_missing=True)
message = pr.fetch_message_by_frontend_message_id(message_id)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@@ -38,7 +38,7 @@ def get_conv_by_frontend_id(
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@@ -52,9 +52,9 @@ def get_tree_by_frontend_id(
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
tree = pr.fetch_message_tree(message)
return utils.prepare_tree(tree, message.thread_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@router.get("/{message_id}/children")
@@ -65,7 +65,7 @@ def get_children_by_frontend_id(
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return [
protocol.Message(
@@ -84,8 +84,8 @@ def get_descendants_by_frontend_id(
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
descendants = pr.fetch_post_descendants(message)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -98,8 +98,8 @@ def get_longest_conv_by_frontend_id(
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -112,6 +112,6 @@ def get_max_children_by_frontend_id(
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post_by_frontend_post_id(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
+13 -13
View File
@@ -8,7 +8,7 @@ from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
@@ -61,9 +61,9 @@ def get_message(
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
if not isinstance(message.payload.payload, PostPayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_POST_ID)
message = pr.fetch_message(message_id)
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Invalid message id", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@@ -89,9 +89,9 @@ def get_tree(
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
tree = pr.fetch_message_tree(message)
return utils.prepare_tree(tree, message.thread_id)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@router.get("/{message_id}/children")
@@ -119,8 +119,8 @@ def get_descendants(
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
descendants = pr.fetch_post_descendants(message)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -132,8 +132,8 @@ def get_longest_conv(
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
conv = pr.fetch_longest_conversation(message.thread_id)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -145,8 +145,8 @@ def get_max_children(
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_post(message_id)
message, children = pr.fetch_message_with_max_children(message.thread_id)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
+6 -6
View File
@@ -4,15 +4,15 @@ from http import HTTPStatus
from uuid import UUID
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import Post
from oasst_backend.models.db_payload import PostPayload
from oasst_backend.models import Message
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.schemas import protocol
def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, PostPayload):
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
@@ -21,10 +21,10 @@ def prepare_conversation(messages: list[Post]) -> protocol.Conversation:
return protocol.Conversation(messages=conv_messages)
def prepare_tree(tree: list[Post], tree_id: UUID) -> protocol.MessageTree:
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, PostPayload):
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(
protocol.Message(
+3 -3
View File
@@ -54,7 +54,7 @@ class JournalWriter:
self.user = user
self.user_id = self.user.id if self.user else None
def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal:
return self.log(
task_type=task.payload_type,
event_type=JournalEventType.text_reply_to_message,
@@ -63,7 +63,7 @@ class JournalWriter:
message_id=message_id,
)
def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal:
return self.log(
task_type=task.payload_type,
event_type=JournalEventType.message_rating,
@@ -72,7 +72,7 @@ class JournalWriter:
message_id=message_id,
)
def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal:
return self.log(
task_type=task.payload_type,
event_type=JournalEventType.message_ranking,
+46 -44
View File
@@ -26,7 +26,7 @@ class PromptRepository:
self.user_id = self.user.id if self.user else None
self.journal = JournalWriter(db, api_client, self.user)
def lookup_user(self, client_user: protocol_schema.User) -> User:
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
if not client_user:
return None
user: User = (
@@ -119,9 +119,7 @@ class PromptRepository:
)
return task
def store_text_reply(
self, text: str, frontend_message_id: str, user_frontend_message_id: str, role: str = None
) -> Message:
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
@@ -481,7 +479,7 @@ class PromptRepository:
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.POST_NOT_FOUND)
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND)
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
@@ -502,7 +500,7 @@ class PromptRepository:
self.db.commit()
@staticmethod
def trace_conversation(messages: list[Post] | dict[UUID, Post], last_message: Post) -> list[Post]:
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
Pick messages from a collection so that the result makes a linear conversation
starting from a message tree root and up to the given message.
@@ -525,66 +523,68 @@ class PromptRepository:
return list(reversed(conv))
def fetch_message_conversation(self, message: Post | UUID) -> list[Post]:
def fetch_message_conversation(self, message: Message | UUID) -> list[Message]:
"""
Fetch a conversation from the tree root and up to this message.
"""
if isinstance(message, UUID):
message = self.fetch_post(message)
message = self.fetch_message(message)
tree_messages = self.fetch_thread(message.thread_id)
tree_messages = self.fetch_message_tree(message.message_tree_id)
return self.trace_conversation(tree_messages, message)
def fetch_message_tree(self, message: Post | UUID) -> list[Post]:
def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]:
"""
Fetch message tree this message belongs to.
"""
if isinstance(message, UUID):
message = self.fetch_post(message)
return self.fetch_thread(message.thread_id)
message = self.fetch_message(message)
return self.fetch_message_tree(message.message_tree_id)
def fetch_message_children(self, message: Post | UUID) -> list[Post]:
def fetch_message_children(self, message: Message | UUID) -> list[Message]:
"""
Get all direct children of this message
"""
if isinstance(message, Post):
if isinstance(message, Message):
message = message.id
children = self.db.query(Post).filter(Post.parent_id == message).all()
children = self.db.query(Message).filter(Message.parent_id == message).all()
return children
@staticmethod
def trace_descendants(root: Post, messages: list[Post]) -> list[Post]:
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
children = defaultdict(list)
for msg in messages:
children[msg.parent_id].append(msg)
def _traverse_subtree(m: Post):
def _traverse_subtree(m: Message):
for child in children[m.id]:
yield child
yield from _traverse_subtree(child)
return list(_traverse_subtree(root))
def fetch_post_descendants(self, message: Post | UUID, max_depth: int = None) -> list[Post]:
def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]:
if isinstance(message, UUID):
message = self.fetch_post(message)
message = self.fetch_message(message)
desc = self.db.query(Post).filter(Post.thread_id == message.thread_id, Post.depth > message.depth)
desc = self.db.query(Message).filter(
Message.message_tree_id == message.message_tree_id, Message.depth > message.depth
)
if max_depth is not None:
desc = desc.filter(Post.depth <= max_depth)
desc = desc.filter(Message.depth <= max_depth)
desc = desc.all()
return self.trace_descendants(message, desc)
def fetch_longest_conversation(self, message: Post | UUID) -> list[Post]:
tree = self.fetch_message_tree(message)
def fetch_longest_conversation(self, message: Message | UUID) -> list[Message]:
tree = self.fetch_tree_from_message(message)
max_message = max(tree, key=lambda m: m.depth)
return self.trace_conversation(tree, max_message)
def fetch_message_with_max_children(self, message: Post | UUID) -> tuple[Post, list[Post]]:
tree = self.fetch_message_tree(message)
def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Message, list[Message]]:
tree = self.fetch_tree_from_message(message)
max_message = max(tree, key=lambda m: m.children_count)
return max_message, [m for m in tree if m.parent_id == max_message.id]
@@ -599,7 +599,7 @@ class PromptRepository:
end_date: Optional[datetime.datetime] = None,
only_roots: bool = False,
deleted: Optional[bool] = None,
) -> list[Post]:
) -> list[Message]:
if not self.api_client.trusted and not api_client_id:
# Let unprivileged api clients query their own messages without api_client_id being set
api_client_id = self.api_client.id
@@ -608,30 +608,30 @@ class PromptRepository:
# Unprivileged api client asks for foreign messages
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
messages = self.db.query(Post)
messages = self.db.query(Message)
if user_id:
messages = messages.filter(Post.person_id == user_id)
messages = messages.filter(Message.user_id == user_id)
if username:
messages = messages.join(Person)
messages = messages.filter(Person.username == username)
messages = messages.join(User)
messages = messages.filter(User.username == username)
if api_client_id:
messages = messages.filter(Post.api_client_id == api_client_id)
messages = messages.filter(Message.api_client_id == api_client_id)
if start_date:
messages = messages.filter(Post.created_date >= start_date)
messages = messages.filter(Message.created_date >= start_date)
if end_date:
messages = messages.filter(Post.created_date < end_date)
messages = messages.filter(Message.created_date < end_date)
if only_roots:
messages = messages.filter(Post.parent_id.is_(None))
messages = messages.filter(Message.parent_id.is_(None))
if deleted is not None:
messages = messages.filter(Post.deleted == deleted)
messages = messages.filter(Message.deleted == deleted)
if desc:
messages = messages.order_by(Post.created_date.desc())
messages = messages.order_by(Message.created_date.desc())
else:
messages = messages.order_by(Post.created_date.asc())
messages = messages.order_by(Message.created_date.asc())
if limit is not None:
messages = messages.limit(limit)
@@ -639,34 +639,36 @@ class PromptRepository:
# TODO: Pagination could be great at some point
return messages.all()
def mark_messages_deleted(self, messages: Post | UUID | list[Post | UUID], recursive: bool = True):
if isinstance(messages, (Post, UUID)):
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
if isinstance(messages, (Message, UUID)):
messages = [messages]
ids = []
for message in messages:
if isinstance(message, UUID):
ids.append(message)
elif isinstance(message, Post):
elif isinstance(message, Message):
ids.append(message.id)
else:
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
query = update(Post).where(Post.id.in_(ids)).values(deleted=True)
query = update(Message).where(Message.id.in_(ids)).values(deleted=True)
self.db.execute(query)
parent_ids = ids
if recursive:
while parent_ids:
query = update(Post).filter(Post.parent_id.in_(parent_ids)).values(deleted=True).returning(Post.id)
query = (
update(Message).filter(Message.parent_id.in_(parent_ids)).values(deleted=True).returning(Message.id)
)
parent_ids = self.db.execute(query).scalars().all()
self.db.commit()
def get_stats(self):
deleted = self.db.query(Post.deleted, func.count()).group_by(Post.deleted)
nthreads = self.db.query(None, func.count(Post.id)).filter(Post.parent_id.is_(None))
deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted)
nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None))
query = deleted.union_all(nthreads)
result = {k: v for k, v in query.all()}
return {