From 8906854dbfbf0a3f9bb0c9ce2e53d0f996f534c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 8 Jan 2023 19:08:47 +0100 Subject: [PATCH] Extract UserRepository and TaskRepository from PromptRepository * Extract classes UserRepository and TaskRepository from PromptRepository * move close_task() to TaskRepository and get_user_leaderboard to UserRepository() * Use UserRepository in leaderboards endpoint, add type annotation to leaderboards endpoint --- backend/main.py | 17 +- .../oasst_backend/api/v1/frontend_messages.py | 14 +- .../oasst_backend/api/v1/frontend_users.py | 4 +- backend/oasst_backend/api/v1/leaderboards.py | 15 +- backend/oasst_backend/api/v1/messages.py | 18 +- backend/oasst_backend/api/v1/stats.py | 2 +- backend/oasst_backend/api/v1/tasks.py | 20 +- backend/oasst_backend/api/v1/text_labels.py | 2 +- backend/oasst_backend/api/v1/users.py | 4 +- .../models/message_tree_state.py | 45 ++- backend/oasst_backend/prompt_repository.py | 326 ++++-------------- backend/oasst_backend/task_repository.py | 199 +++++++++++ backend/oasst_backend/user_repository.py | 64 ++++ 13 files changed, 409 insertions(+), 321 deletions(-) create mode 100644 backend/oasst_backend/task_repository.py create mode 100644 backend/oasst_backend/user_repository.py diff --git a/backend/main.py b/backend/main.py index 1c93fc9f..b84a2d9e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client from oasst_backend.api.v1.api import api_router from oasst_backend.config import settings from oasst_backend.database import engine -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA: with 
Session(engine) as db: api_client = get_dummy_api_client(db) dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") - pr = PromptRepository(db=db, api_client=api_client, user=dummy_user) + + ur = UserRepository(db=db, api_client=api_client) + tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur) + pr = PromptRepository( + db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr + ) with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: dummy_messages_raw = json.load(f) @@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA: dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] for msg in dummy_messages: - task = pr.fetch_task_by_frontend_message_id(msg.task_message_id) + task = tr.fetch_task_by_frontend_message_id(msg.task_message_id) if task and not task.ack: logger.warning("Deleting unacknowledged seed data task") db.delete(task) task = None if not task: if msg.parent_message_id is None: - task = pr.store_task( + task = tr.store_task( protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None ) else: @@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA: for cmsg in conversation_messages ] ) - task = pr.store_task( + task = tr.store_task( protocol_schema.AssistantReplyTask(conversation=conversation), message_tree_id=parent_message.message_tree_id, parent_message_id=parent_message.id, ) - pr.bind_frontend_message_id(task.id, msg.task_message_id) + tr.bind_frontend_message_id(task.id, msg.task_message_id) message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id) logger.info( diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index 420f0d1b..f149bebb 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -16,7 +16,7 @@ def get_message_by_frontend_id( """ 
Get a message by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) return utils.prepare_message(message) @@ -29,7 +29,7 @@ def get_conv_by_frontend_id( Get a conversation from the tree root and up to the message with given frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_conversation(message) return utils.prepare_conversation(messages) @@ -43,7 +43,7 @@ def get_tree_by_frontend_id( Get all messages belonging to the same message tree. Message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) tree = pr.fetch_message_tree(message.message_tree_id) return utils.prepare_tree(tree, message.message_tree_id) @@ -56,7 +56,7 @@ def get_children_by_frontend_id( """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_children(message.id) return utils.prepare_message_list(messages) @@ -70,7 +70,7 @@ def get_descendants_by_frontend_id( Get a subtree which starts with this message. The message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id( Get the longest conversation from the tree of the message. The message is identified by its frontend ID. 
""" - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -98,7 +98,7 @@ def get_max_children_by_frontend_id( Get message with the most children from the tree of the provided message. The message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 0a745462..8d56b7f9 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -29,7 +29,7 @@ def query_frontend_user_messages( """ Query frontend user messages. 
""" - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.query_messages( username=username, api_client_id=api_client_id, @@ -47,6 +47,6 @@ def query_frontend_user_messages( def mark_frontend_user_messages_deleted( username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) messages = pr.query_messages(username=username, api_client_id=api_client.id) pr.mark_messages_deleted(messages) diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py index 4202edad..46aea637 100644 --- a/backend/oasst_backend/api/v1/leaderboards.py +++ b/backend/oasst_backend/api/v1/leaderboards.py @@ -1,7 +1,8 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.user_repository import UserRepository +from oasst_shared.schemas.protocol import LeaderboardStats from sqlmodel import Session router = APIRouter() @@ -11,15 +12,15 @@ router = APIRouter() def get_assistant_leaderboard( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), -): - pr = PromptRepository(db, api_client, None) - return pr.get_user_leaderboard(role="assistant") +) -> LeaderboardStats: + ur = UserRepository(db, api_client) + return ur.get_user_leaderboard(role="assistant") @router.get("/create/prompter") def get_prompter_leaderboard( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), -): - pr = PromptRepository(db, api_client, None) - return pr.get_user_leaderboard(role="prompter") +) -> LeaderboardStats: + ur = UserRepository(db, api_client) + return ur.get_user_leaderboard(role="prompter") diff --git a/backend/oasst_backend/api/v1/messages.py 
b/backend/oasst_backend/api/v1/messages.py index 7a2fd2e9..6229e20c 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -29,7 +29,7 @@ def query_messages( """ Query messages. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.query_messages( username=username, api_client_id=api_client_id, @@ -51,7 +51,7 @@ def get_message( """ Get a message by its internal ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) return utils.prepare_message(message) @@ -64,7 +64,7 @@ def get_conv( Get a conversation from the tree root and up to the message with given internal ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.fetch_message_conversation(message_id) return utils.prepare_conversation(messages) @@ -76,7 +76,7 @@ def get_tree( """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) tree = pr.fetch_message_tree(message.message_tree_id) return utils.prepare_tree(tree, message.message_tree_id) @@ -89,7 +89,7 @@ def get_children( """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.fetch_message_children(message_id) return utils.prepare_message_list(messages) @@ -101,7 +101,7 @@ def get_descendants( """ Get a subtree which starts with this message. 
""" - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -114,7 +114,7 @@ def get_longest_conv( """ Get the longest conversation from the tree of the message. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -127,7 +127,7 @@ def get_max_children( """ Get message with the most children from the tree of the provided message. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) @@ -137,5 +137,5 @@ def get_max_children( def mark_message_deleted( message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) pr.mark_messages_deleted(message_id) diff --git a/backend/oasst_backend/api/v1/stats.py b/backend/oasst_backend/api/v1/stats.py index a54aa07b..1aaffb1b 100644 --- a/backend/oasst_backend/api/v1/stats.py +++ b/backend/oasst_backend/api/v1/stats.py @@ -13,5 +13,5 @@ def get_message_stats( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) return pr.get_stats() diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index adfb2907..eb10dc00 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,7 @@ from fastapi.security.api_key 
import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.prompt_repository import PromptRepository, TaskRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -190,9 +190,9 @@ def request_task( api_client = deps.api_auth(api_key, db) try: - pr = PromptRepository(db, api_client, request.user) + pr = PromptRepository(db, api_client, client_user=request.user) task, message_tree_id, parent_message_id = generate_task(request, pr) - pr.store_task(task, message_tree_id, parent_message_id, request.collective) + pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective) except OasstError: raise @@ -217,11 +217,11 @@ def tasks_acknowledge( api_client = deps.api_auth(api_key, db) try: - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) # here we store the message id in the database for the task logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.") - pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id) + pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id) except OasstError: raise @@ -245,8 +245,8 @@ def tasks_acknowledge_failure( try: logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.") api_client = deps.api_auth(api_key, db) - pr = PromptRepository(db, api_client, user=None) - pr.acknowledge_task_failure(task_id) + pr = PromptRepository(db, api_client) + pr.task_repository.acknowledge_task_failure(task_id) except (KeyError, RuntimeError): logger.exception("Failed to not acknowledge task.") raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED) @@ -265,7 +265,7 @@ def 
tasks_interaction( api_client = deps.api_auth(api_key, db) try: - pr = PromptRepository(db, api_client, user=interaction.user) + pr = PromptRepository(db, api_client, client_user=interaction.user) match type(interaction): case protocol_schema.TextReplyToMessage: @@ -323,6 +323,6 @@ def close_collective_task( api_key: APIKey = Depends(deps.get_api_key), ): api_client = deps.api_auth(api_key, db) - pr = PromptRepository(db, api_client, user=None) - pr.close_task(close_task_request.message_id) + tr = TaskRepository(db, api_client) + tr.close_task(close_task_request.message_id) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index 03fd2cb4..c9afd88c 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -25,7 +25,7 @@ def label_text( try: logger.info(f"Labeling text {text_labels=}.") - pr = PromptRepository(db, api_client, user=text_labels.user) + pr = PromptRepository(db, api_client, client_user=text_labels.user) pr.store_text_labels(text_labels) except Exception: diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 8d55bfec..5dda88eb 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -29,7 +29,7 @@ def query_user_messages( """ Query user messages. 
""" - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.query_messages( user_id=user_id, api_client_id=api_client_id, @@ -48,6 +48,6 @@ def query_user_messages( def mark_user_messages_deleted( user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) messages = pr.query_messages(user_id=user_id) pr.mark_messages_deleted(messages) diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index 386595e9..97ad34eb 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -6,27 +6,56 @@ import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg from sqlmodel import Field, Index, SQLModel -# The types of States a message tree can have. +class States(str, Enum): + """States of the Open-Assistant message tree state machine.""" + + INITIAL_PROMPT_REVIEW = "initial_prompt_review" + """In this state the message tree consists only of a single inital prompt root node. + Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or + `aborted_low_grade`.""" -class States(Enum): - INITIAL = "initial" BREEDING_PHASE = "breeding_phase" + """Assistant & prompter human demonstrations are collected. Concurrently labeling tasks + are handed out to check if the quality of the replies surpasses the minimum acceptable + quality. + When the required number of messages passing the initial labelling-quality check has been + collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses + are received the tree can also enter the `aborted_low_grade` state.""" + RANKING_PHASE = "ranking_phase" + """The tree has been successfully populated with the desired number of messages. 
Ranking + tasks are now handed out for all nodes with more than one child.""" + READY_FOR_SCORING = "ready_for_scoring" - CHILDREN_SCORED = "children_scored" - FINAL = "final" + """Required ranking responses have been collected and the scoring algorithm can now + compute the aggregated ranking scores that will appear in the dataset.""" + + READY_FOR_EXPORT = "ready_for_export" + """The Scoring algorithm computed rankings scores for all children. The message tree can be + exported as part of an Open-Assistant message tree dataset.""" + + SCORING_FAILED = "scoring_failed" + """An exception occurred in the scoring algorithm.""" + + ABORTED_LOW_GRADE = "aborted_low_grade" + """The system received too many bad reviews and stopped handing out tasks for this message tree.""" + + HALTED_BY_MODERATOR = "halted_by_moderator" + """A moderator decided to manually halt the message tree construction process.""" VALID_STATES = ( - States.INITIAL, + States.INITIAL_PROMPT_REVIEW, States.BREEDING_PHASE, States.RANKING_PHASE, States.READY_FOR_SCORING, - States.CHILDREN_SCORED, - States.FINAL, + States.READY_FOR_EXPORT, + States.ABORTED_LOW_GRADE, ) +TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR) + class MessageTreeState(SQLModel, table=True): __tablename__ = "message_tree_state" diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7c7dd7b6..7446ec07 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -8,98 +8,39 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User +from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User from oasst_backend.models.payload_column_type 
import PayloadContainer +from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id +from oasst_backend.user_repository import UserRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats +from oasst_shared.schemas.protocol import SystemStats from sqlalchemy import update from sqlmodel import Session, func from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND class PromptRepository: - def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): + def __init__( + self, + db: Session, + api_client: ApiClient, + client_user: Optional[protocol_schema.User] = None, + user_repository: Optional[UserRepository] = None, + task_repository: Optional[TaskRepository] = None, + ): self.db = db self.api_client = api_client - self.user = self.lookup_user(user) + self.user_repository = user_repository or UserRepository(db, api_client) + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) self.user_id = self.user.id if self.user else None + self.task_repository = task_repository or TaskRepository( + db, api_client, client_user, user_repository=self.user_repository + ) self.journal = JournalWriter(db, api_client, self.user) - def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]: - if not client_user: - return None - user: User = ( - self.db.query(User) - .filter( - User.api_client_id == self.api_client.id, - User.username == client_user.id, - User.auth_method == client_user.auth_method, - ) - .first() - ) - if user is None: - # user is unknown, create new record - user = User( - username=client_user.id, - display_name=client_user.display_name, - api_client_id=self.api_client.id, - auth_method=client_user.auth_method, - ) - self.db.add(user) - self.db.commit() - self.db.refresh(user) - elif 
client_user.display_name and client_user.display_name != user.display_name: - # we found the user but the display name changed - user.display_name = client_user.display_name - self.db.add(user) - self.db.commit() - return user - - def validate_frontend_message_id(self, message_id: str) -> None: - # TODO: Should it be replaced with fastapi/pydantic validation? - if not isinstance(message_id, str): - raise OasstError( - f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID - ) - if not message_id: - raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) - - def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): - self.validate_frontend_message_id(frontend_message_id) - - # find task - task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() - if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) - if task.expired: - raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if task.done or task.ack is not None: - raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) - - task.frontend_message_id = frontend_message_id - task.ack = True - # ToDo: check race-condition, transaction - self.db.add(task) - self.db.commit() - - def acknowledge_task_failure(self, task_id): - # find task - task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() - if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) - if task.expired: - raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if task.done or task.ack is not None: - raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) - - task.ack = False - # ToDo: check race-condition, transaction - self.db.add(task) - 
self.db.commit() - def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message: - self.validate_frontend_message_id(frontend_message_id) + validate_frontend_message_id(frontend_message_id) message: Message = ( self.db.query(Message) .filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id) @@ -113,20 +54,48 @@ class PromptRepository: ) return message - def fetch_task_by_frontend_message_id(self, message_id: str) -> Task: - self.validate_frontend_message_id(message_id) - task = ( - self.db.query(Task) - .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id) - .one_or_none() + def insert_message( + self, + *, + message_id: UUID, + frontend_message_id: str, + parent_id: UUID, + message_tree_id: UUID, + task_id: UUID, + role: str, + payload: db_payload.MessagePayload, + payload_type: str = None, + depth: int = 0, + ) -> Message: + if payload_type is None: + if payload is None: + payload_type = "null" + else: + payload_type = type(payload).__name__ + + message = Message( + id=message_id, + parent_id=parent_id, + message_tree_id=message_tree_id, + task_id=task_id, + user_id=self.user_id, + role=role, + frontend_message_id=frontend_message_id, + api_client_id=self.api_client.id, + payload_type=payload_type, + payload=PayloadContainer(payload=payload), + depth=depth, ) - return task + self.db.add(message) + self.db.commit() + self.db.refresh(message) + return message def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message: - self.validate_frontend_message_id(frontend_message_id) - self.validate_frontend_message_id(user_frontend_message_id) + validate_frontend_message_id(frontend_message_id) + validate_frontend_message_id(user_frontend_message_id) - task = self.fetch_task_by_frontend_message_id(frontend_message_id) + task = 
self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id) if task is None: raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND) @@ -174,7 +143,7 @@ class PromptRepository: def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction: message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True) - task = self.fetch_task_by_frontend_message_id(rating.message_id) + task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id) task_payload: db_payload.RateSummaryPayload = task.payload.payload if type(task_payload) != db_payload.RateSummaryPayload: raise OasstError( @@ -201,7 +170,7 @@ class PromptRepository: def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction: # fetch task - task = self.fetch_task_by_frontend_message_id(ranking.message_id) + task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id) if not task.collective: task.done = True self.db.add(task) @@ -255,142 +224,6 @@ class PromptRepository: OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH, ) - def store_task( - self, - task: protocol_schema.Task, - message_tree_id: UUID = None, - parent_message_id: UUID = None, - collective: bool = False, - ) -> Task: - payload: db_payload.TaskPayload - match type(task): - case protocol_schema.SummarizeStoryTask: - payload = db_payload.SummarizationStoryPayload(story=task.story) - - case protocol_schema.RateSummaryTask: - payload = db_payload.RateSummaryPayload( - full_text=task.full_text, summary=task.summary, scale=task.scale - ) - - case protocol_schema.InitialPromptTask: - payload = db_payload.InitialPromptPayload(hint=task.hint) - - case protocol_schema.PrompterReplyTask: - payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint) - - case protocol_schema.AssistantReplyTask: - payload = db_payload.AssistantReplyPayload(type=task.type, 
conversation=task.conversation) - - case protocol_schema.RankInitialPromptsTask: - payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts) - - case protocol_schema.RankPrompterRepliesTask: - payload = db_payload.RankPrompterRepliesPayload( - type=task.type, conversation=task.conversation, replies=task.replies - ) - - case protocol_schema.RankAssistantRepliesTask: - payload = db_payload.RankAssistantRepliesPayload( - type=task.type, conversation=task.conversation, replies=task.replies - ) - - case protocol_schema.LabelInitialPromptTask: - payload = db_payload.LabelInitialPromptPayload( - type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels - ) - - case protocol_schema.LabelPrompterReplyTask: - payload = db_payload.LabelPrompterReplyPayload( - type=task.type, - message_id=task.message_id, - conversation=task.conversation, - reply=task.reply, - valid_labels=task.valid_labels, - ) - - case protocol_schema.LabelAssistantReplyTask: - payload = db_payload.LabelAssistantReplyPayload( - type=task.type, - message_id=task.message_id, - conversation=task.conversation, - reply=task.reply, - valid_labels=task.valid_labels, - ) - - case _: - raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) - - task = self.insert_task( - payload=payload, - id=task.id, - message_tree_id=message_tree_id, - parent_message_id=parent_message_id, - collective=collective, - ) - assert task.id == task.id - return task - - def insert_task( - self, - payload: db_payload.TaskPayload, - id: UUID = None, - message_tree_id: UUID = None, - parent_message_id: UUID = None, - collective: bool = False, - ) -> Task: - c = PayloadContainer(payload=payload) - task = Task( - id=id, - user_id=self.user_id, - payload_type=type(payload).__name__, - payload=c, - api_client_id=self.api_client.id, - message_tree_id=message_tree_id, - parent_message_id=parent_message_id, - collective=collective, - ) - self.db.add(task) 
- self.db.commit() - self.db.refresh(task) - return task - - def insert_message( - self, - *, - message_id: UUID, - frontend_message_id: str, - parent_id: UUID, - message_tree_id: UUID, - task_id: UUID, - role: str, - payload: db_payload.MessagePayload, - payload_type: str = None, - depth: int = 0, - ) -> Message: - if payload_type is None: - if payload is None: - payload_type = "null" - else: - payload_type = type(payload).__name__ - - message = Message( - id=message_id, - parent_id=parent_id, - message_tree_id=message_tree_id, - task_id=task_id, - user_id=self.user_id, - role=role, - frontend_message_id=frontend_message_id, - api_client_id=self.api_client.id, - payload_type=payload_type, - payload=PayloadContainer(payload=payload), - depth=depth, - ) - self.db.add(message) - self.db.commit() - self.db.refresh(message) - return message - def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) @@ -515,28 +348,6 @@ class PromptRepository: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) return message - def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): - """ - Mark task as done. No further messages will be accepted for this task. 
- """ - self.validate_frontend_message_id(frontend_message_id) - task = self.fetch_task_by_frontend_message_id(frontend_message_id) - - if not task: - raise OasstError( - f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND - ) - if task.expired: - raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED) - if not allow_personal_tasks and not task.collective: - raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE) - if task.done: - raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE) - - task.done = True - self.db.add(task) - self.db.commit() - @staticmethod def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]: """ @@ -728,24 +539,3 @@ class PromptRepository: deleted=result.get(True, 0), message_trees=result.get(None, 0), ) - - def get_user_leaderboard(self, role: str) -> LeaderboardStats: - """ - Get leaderboard stats for Messages created, - separate leaderboard for prompts & assistants - - """ - query = ( - self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id)) - .join(User, User.id == Message.user_id, isouter=True) - .filter(Message.deleted is not True, Message.role == role) - .group_by(Message.user_id, User.username, User.display_name) - .order_by(func.count(Message.user_id).desc()) - ) - - result = [ - {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]} - for i, j in enumerate(query.all(), start=1) - ] - - return LeaderboardStats(leaderboard=result) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py new file mode 100644 index 00000000..15484d66 --- /dev/null +++ b/backend/oasst_backend/task_repository.py @@ -0,0 +1,199 @@ +from typing import Optional +from uuid import UUID + +import oasst_backend.models.db_payload as db_payload +from oasst_backend.models import ApiClient, Task +from 
oasst_backend.models.payload_column_type import PayloadContainer +from oasst_backend.user_repository import UserRepository +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from oasst_shared.schemas import protocol as protocol_schema +from sqlmodel import Session +from starlette.status import HTTP_404_NOT_FOUND + + +def validate_frontend_message_id(message_id: str) -> None: + # TODO: Should it be replaced with fastapi/pydantic validation? + if not isinstance(message_id, str): + raise OasstError( + f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID + ) + if not message_id: + raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) + + +class TaskRepository: + def __init__( + self, + db: Session, + api_client: ApiClient, + client_user: Optional[protocol_schema.User], + user_repository: UserRepository, + ): + self.db = db + self.api_client = api_client + self.user_repository = user_repository + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) + self.user_id = self.user.id if self.user else None + + def store_task( + self, + task: protocol_schema.Task, + message_tree_id: UUID = None, + parent_message_id: UUID = None, + collective: bool = False, + ) -> Task: + payload: db_payload.TaskPayload + match type(task): + case protocol_schema.SummarizeStoryTask: + payload = db_payload.SummarizationStoryPayload(story=task.story) + + case protocol_schema.RateSummaryTask: + payload = db_payload.RateSummaryPayload( + full_text=task.full_text, summary=task.summary, scale=task.scale + ) + + case protocol_schema.InitialPromptTask: + payload = db_payload.InitialPromptPayload(hint=task.hint) + + case protocol_schema.PrompterReplyTask: + payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint) + + case protocol_schema.AssistantReplyTask: + payload = db_payload.AssistantReplyPayload(type=task.type, 
conversation=task.conversation) + + case protocol_schema.RankInitialPromptsTask: + payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts) + + case protocol_schema.RankPrompterRepliesTask: + payload = db_payload.RankPrompterRepliesPayload( + type=task.type, conversation=task.conversation, replies=task.replies + ) + + case protocol_schema.RankAssistantRepliesTask: + payload = db_payload.RankAssistantRepliesPayload( + type=task.type, conversation=task.conversation, replies=task.replies + ) + + case protocol_schema.LabelInitialPromptTask: + payload = db_payload.LabelInitialPromptPayload( + type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels + ) + + case protocol_schema.LabelPrompterReplyTask: + payload = db_payload.LabelPrompterReplyPayload( + type=task.type, + message_id=task.message_id, + conversation=task.conversation, + reply=task.reply, + valid_labels=task.valid_labels, + ) + + case protocol_schema.LabelAssistantReplyTask: + payload = db_payload.LabelAssistantReplyPayload( + type=task.type, + message_id=task.message_id, + conversation=task.conversation, + reply=task.reply, + valid_labels=task.valid_labels, + ) + + case _: + raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) + + task_model = self.insert_task( + payload=payload, + id=task.id, + message_tree_id=message_tree_id, + parent_message_id=parent_message_id, + collective=collective, + ) + assert task_model.id == task.id  # sanity check: the stored row must keep the protocol task's id + return task_model + + def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): + validate_frontend_message_id(frontend_message_id) + + # find task + task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() + if task is None: + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) + if task.expired: + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) + if
task.done or task.ack is not None: + raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) + + task.frontend_message_id = frontend_message_id + task.ack = True + # ToDo: check race-condition, transaction + self.db.add(task) + self.db.commit() + + def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): + """ + Mark task as done. No further messages will be accepted for this task. + """ + validate_frontend_message_id(frontend_message_id) + task = self.fetch_task_by_frontend_message_id(frontend_message_id) + + if not task: + raise OasstError( + f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND + ) + if task.expired: + raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED) + if not allow_personal_tasks and not task.collective: + raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE) + if task.done: + raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE) + + task.done = True + self.db.add(task) + self.db.commit() + + def acknowledge_task_failure(self, task_id): + # find task + task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() + if task is None: + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) + if task.expired: + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) + if task.done or task.ack is not None: + raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) + + task.ack = False + # ToDo: check race-condition, transaction + self.db.add(task) + self.db.commit() + + def insert_task( + self, + payload: db_payload.TaskPayload, + id: UUID = None, + message_tree_id: UUID = None, + parent_message_id: UUID = None, + collective: bool = False, + ) -> Task: + c = PayloadContainer(payload=payload) + task = Task( + id=id, + user_id=self.user_id, +
payload_type=type(payload).__name__, + payload=c, + api_client_id=self.api_client.id, + message_tree_id=message_tree_id, + parent_message_id=parent_message_id, + collective=collective, + ) + self.db.add(task) + self.db.commit() + self.db.refresh(task) + return task + + def fetch_task_by_frontend_message_id(self, message_id: str) -> Task: + validate_frontend_message_id(message_id) + task = ( + self.db.query(Task) + .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id) + .one_or_none() + ) + return task diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py new file mode 100644 index 00000000..b5508899 --- /dev/null +++ b/backend/oasst_backend/user_repository.py @@ -0,0 +1,64 @@ +from typing import Optional + +from oasst_backend.models import ApiClient, Message, User +from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import LeaderboardStats +from sqlmodel import Session, func + + +class UserRepository: + def __init__(self, db: Session, api_client: ApiClient): + self.db = db + self.api_client = api_client + + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + if not client_user: + return None + user: User = ( + self.db.query(User) + .filter( + User.api_client_id == self.api_client.id, + User.username == client_user.id, + User.auth_method == client_user.auth_method, + ) + .first() + ) + if user is None: + if create_missing: + # user is unknown, create new record + user = User( + username=client_user.id, + display_name=client_user.display_name, + api_client_id=self.api_client.id, + auth_method=client_user.auth_method, + ) + self.db.add(user) + self.db.commit() + self.db.refresh(user) + elif client_user.display_name and client_user.display_name != user.display_name: + # we found the user but the display name changed + user.display_name = client_user.display_name + self.db.add(user) + 
self.db.commit() + return user + + def get_user_leaderboard(self, role: str) -> LeaderboardStats: + """ + Get leaderboard stats for Messages created, + separate leaderboard for prompts & assistants + Messages whose deleted flag is true are not counted. + """ + query = ( + self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id)) + .join(User, User.id == Message.user_id, isouter=True) + .filter(Message.deleted.is_not(True), Message.role == role) + .group_by(Message.user_id, User.username, User.display_name) + .order_by(func.count(Message.user_id).desc()) + ) + + result = [ + {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]} + for i, j in enumerate(query.all(), start=1) + ] + + return LeaderboardStats(leaderboard=result)