Extract UserRepository and TaskRepository from PromptRepository

* Extract classes UserRepository and TaskRepository from PromptRepository
* Move close_task() to TaskRepository and get_user_leaderboard() to UserRepository
* Use UserRepository in the leaderboard endpoints and add return type annotations to them
This commit is contained in:
Andreas Köpf
2023-01-08 19:08:47 +01:00
committed by GitHub
parent 10b9d4608a
commit 8906854dbf
13 changed files with 409 additions and 321 deletions
+11 -6
View File
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
with Session(engine) as db:
api_client = get_dummy_api_client(db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
ur = UserRepository(db=db, api_client=api_client)
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
pr = PromptRepository(
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = pr.store_task(
task = tr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
for cmsg in conversation_messages
]
)
task = pr.store_task(
task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_message_id(task.id, msg.task_message_id)
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
logger.info(
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
return utils.prepare_message(message)
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
Get a conversation from the tree root and up to the message with given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
"""
Get the children of the message with the given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return utils.prepare_message_list(messages)
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
+8 -7
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
router = APIRouter()
@@ -11,15 +12,15 @@ router = APIRouter()
def get_assistant_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="assistant")
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="prompter")
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="prompter")
+9 -9
View File
@@ -29,7 +29,7 @@ def query_messages(
"""
Query messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -51,7 +51,7 @@ def get_message(
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
return utils.prepare_message(message)
@@ -64,7 +64,7 @@ def get_conv(
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@@ -76,7 +76,7 @@ def get_tree(
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -89,7 +89,7 @@ def get_children(
"""
Get the children of the message with the given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.fetch_message_children(message_id)
return utils.prepare_message_list(messages)
@@ -101,7 +101,7 @@ def get_descendants(
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -114,7 +114,7 @@ def get_longest_conv(
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -127,7 +127,7 @@ def get_max_children(
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -137,5 +137,5 @@ def get_max_children(
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
pr.mark_messages_deleted(message_id)
+1 -1
View File
@@ -13,5 +13,5 @@ def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
return pr.get_stats()
+10 -10
View File
@@ -7,7 +7,7 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -190,9 +190,9 @@ def request_task(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, request.user)
pr = PromptRepository(db, api_client, client_user=request.user)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
@@ -217,11 +217,11 @@ def tasks_acknowledge(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
raise
@@ -245,8 +245,8 @@ def tasks_acknowledge_failure(
try:
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.acknowledge_task_failure(task_id)
pr = PromptRepository(db, api_client)
pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@@ -265,7 +265,7 @@ def tasks_interaction(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=interaction.user)
pr = PromptRepository(db, api_client, client_user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToMessage:
@@ -323,6 +323,6 @@ def close_collective_task(
api_key: APIKey = Depends(deps.get_api_key),
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.message_id)
tr = TaskRepository(db, api_client)
tr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
+1 -1
View File
@@ -25,7 +25,7 @@ def label_text(
try:
logger.info(f"Labeling text {text_labels=}.")
pr = PromptRepository(db, api_client, user=text_labels.user)
pr = PromptRepository(db, api_client, client_user=text_labels.user)
pr.store_text_labels(text_labels)
except Exception:
+2 -2
View File
@@ -29,7 +29,7 @@ def query_user_messages(
"""
Query user messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
user_id=user_id,
api_client_id=api_client_id,
@@ -48,6 +48,6 @@ def query_user_messages(
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
@@ -6,27 +6,56 @@ import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
# The types of States a message tree can have.
class States(str, Enum):
"""States of the Open-Assistant message tree state machine."""
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
"""In this state the message tree consists only of a single initial prompt root node.
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
`aborted_low_grade`."""
class States(Enum):
INITIAL = "initial"
BREEDING_PHASE = "breeding_phase"
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
are handed out to check if the quality of the replies surpasses the minimum acceptable
quality.
When the required number of messages passing the initial labelling-quality check has been
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
are received the tree can also enter the `aborted_low_grade` state."""
RANKING_PHASE = "ranking_phase"
"""The tree has been successfully populated with the desired number of messages. Ranking
tasks are now handed out for all nodes with more than one child."""
READY_FOR_SCORING = "ready_for_scoring"
CHILDREN_SCORED = "children_scored"
FINAL = "final"
"""Required ranking responses have been collected and the scoring algorithm can now
compute the aggregated ranking scores that will appear in the dataset."""
READY_FOR_EXPORT = "ready_for_export"
"""The Scoring algorithm computed ranking scores for all children. The message tree can be
exported as part of an Open-Assistant message tree dataset."""
SCORING_FAILED = "scoring_failed"
"""An exception occurred in the scoring algorithm."""
ABORTED_LOW_GRADE = "aborted_low_grade"
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
HALTED_BY_MODERATOR = "halted_by_moderator"
"""A moderator decided to manually halt the message tree construction process."""
VALID_STATES = (
States.INITIAL,
States.INITIAL_PROMPT_REVIEW,
States.BREEDING_PHASE,
States.RANKING_PHASE,
States.READY_FOR_SCORING,
States.CHILDREN_SCORED,
States.FINAL,
States.READY_FOR_EXPORT,
States.ABORTED_LOW_GRADE,
)
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
class MessageTreeState(SQLModel, table=True):
__tablename__ = "message_tree_state"
+58 -268
View File
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
def __init__(
self,
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User] = None,
user_repository: Optional[UserRepository] = None,
task_repository: Optional[TaskRepository] = None,
):
self.db = db
self.api_client = api_client
self.user = self.lookup_user(user)
self.user_repository = user_repository or UserRepository(db, api_client)
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
)
self.journal = JournalWriter(db, api_client, self.user)
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
if not client_user:
return None
user: User = (
self.db.query(User)
.filter(
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
)
.first()
)
if user is None:
# user is unknown, create new record
user = User(
username=client_user.id,
display_name=client_user.display_name,
api_client_id=self.api_client.id,
auth_method=client_user.auth_method,
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return user
def validate_frontend_message_id(self, message_id: str) -> None:
# TODO: Should it be replaced with fastapi/pydantic validation?
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
)
if not message_id:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
self.validate_frontend_message_id(frontend_message_id)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
self.validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(frontend_message_id)
message: Message = (
self.db.query(Message)
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
@@ -113,20 +54,48 @@ class PromptRepository:
)
return message
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
self.validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
def insert_message(
self,
*,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
message = Message(
id=message_id,
parent_id=parent_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
return task
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
@@ -174,7 +143,7 @@ class PromptRepository:
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
task = self.fetch_task_by_frontend_message_id(rating.message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
@@ -201,7 +170,7 @@ class PromptRepository:
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch task
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
@@ -255,142 +224,6 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
id=id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def insert_message(
self,
*,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
message = Message(
id=message_id,
parent_id=parent_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
@@ -515,28 +348,6 @@ class PromptRepository:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
self.validate_frontend_message_id(frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if not task:
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
task.done = True
self.db.add(task)
self.db.commit()
@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
@@ -728,24 +539,3 @@ class PromptRepository:
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)
+199
View File
@@ -0,0 +1,199 @@
from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_404_NOT_FOUND
def validate_frontend_message_id(message_id: str) -> None:
    """Reject frontend message ids that are not non-empty strings.

    Returns nothing when ``message_id`` is a non-empty ``str``.

    Raises:
        OasstError: with code ``INVALID_FRONTEND_MESSAGE_ID`` when
            ``message_id`` is not a ``str`` or is an empty string.
    """
    # TODO: Should it be replaced with fastapi/pydantic validation?
    if isinstance(message_id, str):
        if message_id:
            return
        problem = "message_id must not be empty"
    else:
        problem = f"message_id must be string, not {type(message_id)}"
    raise OasstError(problem, OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
class TaskRepository:
def __init__(
    self,
    db: Session,
    api_client: ApiClient,
    client_user: Optional[protocol_schema.User] = None,
    user_repository: Optional[UserRepository] = None,
):
    """Bind the repository to a session, an API client and an optional user.

    Args:
        db: Session used for task queries and commits.
        api_client: API client whose id scopes the task lookups.
        client_user: Frontend user the tasks are handled for. Optional, so
            callers that act without a user can omit it.
        user_repository: Repository used to resolve ``client_user``. When
            omitted, a ``UserRepository`` is created from ``db`` and
            ``api_client``.
    """
    self.db = db
    self.api_client = api_client
    # Both trailing arguments are optional so that two-argument construction,
    # ``TaskRepository(db, api_client)``, does not fail with a TypeError.
    if user_repository is None:
        user_repository = UserRepository(db, api_client)
    self.user_repository = user_repository
    # NOTE(review): lookup_client_user is presumed to return None for a
    # missing client_user -- confirm against UserRepository.
    self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
    self.user_id = self.user.id if self.user else None
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
    """
    Acknowledge a task and attach the frontend's message id to it.

    The task is looked up by its id among the tasks of the calling API client.
    On success the task gets ``frontend_message_id`` assigned, ``ack`` set to
    ``True`` and the change is committed.

    Raises:
        OasstError: TASK_NOT_FOUND (HTTP 404) if no such task exists for this
            API client, TASK_EXPIRED if the task is expired, or
            TASK_ALREADY_UPDATED if the task is done or was already
            (positively or negatively) acknowledged.
    """
    validate_frontend_message_id(frontend_message_id)

    # restrict the lookup to tasks that belong to the calling API client
    client_tasks = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id)
    task: Task = client_tasks.first()

    if task is None:
        raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
    if task.expired:
        raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)

    # `ack` is tri-state: None means "not answered yet", True/False are final
    already_updated = task.done or task.ack is not None
    if already_updated:
        raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)

    task.frontend_message_id = frontend_message_id
    task.ack = True

    # ToDo: check race-condition, transaction
    self.db.add(task)
    self.db.commit()
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
    """
    Mark task as done. No further messages will be accepted for this task.

    Args:
        frontend_message_id: frontend message id the task was bound to.
        allow_personal_tasks: when False (default) only tasks flagged as
            ``collective`` may be closed; pass True to also close
            non-collective tasks.

    Raises:
        OasstError: TASK_NOT_FOUND (HTTP 404) if no task is bound to the given
            frontend message id, TASK_EXPIRED if the task is expired,
            TASK_NOT_COLLECTIVE if the task is not collective and
            ``allow_personal_tasks`` is False, or TASK_ALREADY_DONE if the
            task was closed before.
    """
    validate_frontend_message_id(frontend_message_id)

    # The lookup lives on this repository itself; going through a
    # `self.task_repository` attribute was a leftover from the time this
    # method was part of PromptRepository.
    task = self.fetch_task_by_frontend_message_id(frontend_message_id)

    if not task:
        raise OasstError(
            f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
        )
    if task.expired:
        raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
    if not allow_personal_tasks and not task.collective:
        raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
    if task.done:
        raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)

    task.done = True
    self.db.add(task)
    self.db.commit()
def acknowledge_task_failure(self, task_id):
    """
    Record a negative acknowledgement (``ack = False``) for a task.

    The task is looked up by its id among the tasks of the calling API client
    and the change is committed.

    Raises:
        OasstError: TASK_NOT_FOUND (HTTP 404) if no such task exists for this
            API client, TASK_EXPIRED if the task is expired, or
            TASK_ALREADY_UPDATED if the task is done or already has an
            acknowledgement (of either kind).
    """
    # only tasks of the calling API client are considered
    lookup = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id)
    task: Task = lookup.first()

    if task is None:
        raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
    if task.expired:
        raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)

    has_final_state = task.done or task.ack is not None
    if has_final_state:
        raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)

    task.ack = False

    # ToDo: check race-condition, transaction
    self.db.add(task)
    self.db.commit()
def insert_task(
    self,
    payload: db_payload.TaskPayload,
    id: UUID = None,
    message_tree_id: UUID = None,
    parent_message_id: UUID = None,
    collective: bool = False,
) -> Task:
    """
    Persist a new task row for the current user and API client.

    The payload is wrapped in a ``PayloadContainer`` and its class name is
    stored as ``payload_type``. The row is added, committed and refreshed
    from the database before it is returned.
    """
    container = PayloadContainer(payload=payload)

    # collect the column values first, then build the model in one go
    columns = dict(
        id=id,
        user_id=self.user_id,
        payload_type=type(payload).__name__,
        payload=container,
        api_client_id=self.api_client.id,
        message_tree_id=message_tree_id,
        parent_message_id=parent_message_id,
        collective=collective,
    )
    new_task = Task(**columns)

    session = self.db
    session.add(new_task)
    session.commit()
    session.refresh(new_task)
    return new_task
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
    """
    Look up the task of the calling API client bound to a frontend message id.

    The id is validated first. The query uses ``one_or_none()``, so the
    result is the single matching task or ``None`` when nothing matches.
    """
    validate_frontend_message_id(message_id)

    task_query = self.db.query(Task)
    owned_by_client = Task.api_client_id == self.api_client.id
    bound_to_message = Task.frontend_message_id == message_id
    return task_query.filter(owned_by_client, bound_to_message).one_or_none()
+64
View File
@@ -0,0 +1,64 @@
from typing import Optional
from oasst_backend.models import ApiClient, Message, User
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
class UserRepository:
    """Database access for user records, bound to one session and one API client."""

    def __init__(self, db: Session, api_client: ApiClient):
        """
        Args:
            db: session used for all queries and commits of this repository.
            api_client: API client on whose behalf users are looked up/created.
        """
        self.db = db
        self.api_client = api_client

    def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
        """
        Map a frontend (protocol) user onto the stored ``User`` row.

        The row is matched on API client, ``username == client_user.id`` and
        auth method. If it exists and the frontend reports a different,
        non-empty display name, the stored display name is updated.

        Args:
            client_user: user as sent by the frontend; a falsy value yields None.
            create_missing: create and commit a new row when no match exists.

        Returns:
            The stored user, or None when ``client_user`` is falsy or when the
            user is unknown and ``create_missing`` is False.
        """
        if not client_user:
            return None

        user: User = (
            self.db.query(User)
            .filter(
                User.api_client_id == self.api_client.id,
                User.username == client_user.id,
                User.auth_method == client_user.auth_method,
            )
            .first()
        )

        if user is None:
            if create_missing:
                # user is unknown, create new record
                user = User(
                    username=client_user.id,
                    display_name=client_user.display_name,
                    api_client_id=self.api_client.id,
                    auth_method=client_user.auth_method,
                )
                self.db.add(user)
                self.db.commit()
                self.db.refresh(user)
        elif client_user.display_name and client_user.display_name != user.display_name:
            # we found the user but the display name changed
            user.display_name = client_user.display_name
            self.db.add(user)
            self.db.commit()

        return user

    def get_user_leaderboard(self, role: str) -> LeaderboardStats:
        """
        Get leaderboard stats for Messages created,
        separate leaderboard for prompts & assistants

        Args:
            role: message role to count; only messages with this role are scored.

        Returns:
            LeaderboardStats whose entries are ordered by descending message
            count, with ``ranking`` starting at 1.
        """
        message_count = func.count(Message.user_id)
        query = (
            self.db.query(Message.user_id, User.username, User.display_name, message_count)
            # outer join: messages without a matching user row are still counted
            .join(User, User.id == Message.user_id, isouter=True)
            # A Python `is not True` on the column attribute is evaluated by
            # Python itself (always True) and never reaches the database;
            # `is_not()` builds the SQL `IS NOT` comparison so deleted
            # messages are actually excluded.
            .filter(Message.deleted.is_not(True), Message.role == role)
            .group_by(Message.user_id, User.username, User.display_name)
            .order_by(message_count.desc())
        )

        result = [
            {
                "ranking": ranking,
                "user_id": user_id,
                "username": username,
                "display_name": display_name,
                "score": score,
            }
            for ranking, (user_id, username, display_name, score) in enumerate(query.all(), start=1)
        ]
        return LeaderboardStats(leaderboard=result)