mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-03 17:10:10 +08:00
Extract UserRepository and TaskRepository from PromptRepository
* Extract classes UserRepository and TaskRepository from PromptRepository * move close_task() to TaskRepository and get_user_leaderboard to UserRepository() * Use UserRepository in leaderboards endpoint, add type annotation to leaderboards endpoint
This commit is contained in:
+11
-6
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
|
||||
for msg in dummy_messages:
|
||||
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = pr.store_task(
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
for cmsg in conversation_messages
|
||||
]
|
||||
)
|
||||
task = pr.store_task(
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
pr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return utils.prepare_conversation(messages)
|
||||
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return utils.prepare_message_list(messages)
|
||||
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
@@ -11,15 +12,15 @@ router = APIRouter()
|
||||
def get_assistant_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="assistant")
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="assistant")
|
||||
|
||||
|
||||
@router.get("/create/prompter")
|
||||
def get_prompter_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="prompter")
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="prompter")
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_messages(
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
@@ -51,7 +51,7 @@ def get_message(
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
@@ -64,7 +64,7 @@ def get_conv(
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
@@ -76,7 +76,7 @@ def get_tree(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -89,7 +89,7 @@ def get_children(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
@@ -101,7 +101,7 @@ def get_descendants(
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -114,7 +114,7 @@ def get_longest_conv(
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -127,7 +127,7 @@ def get_max_children(
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -137,5 +137,5 @@ def get_max_children(
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
|
||||
@@ -13,5 +13,5 @@ def get_message_stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
return pr.get_stats()
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -190,9 +190,9 @@ def request_task(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr = PromptRepository(db, api_client, client_user=request.user)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -217,11 +217,11 @@ def tasks_acknowledge(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -245,8 +245,8 @@ def tasks_acknowledge_failure(
|
||||
try:
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.acknowledge_task_failure(task_id)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
@@ -265,7 +265,7 @@ def tasks_interaction(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
pr = PromptRepository(db, api_client, client_user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
@@ -323,6 +323,6 @@ def close_collective_task(
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.close_task(close_task_request.message_id)
|
||||
tr = TaskRepository(db, api_client)
|
||||
tr.close_task(close_task_request.message_id)
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -25,7 +25,7 @@ def label_text(
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {text_labels=}.")
|
||||
pr = PromptRepository(db, api_client, user=text_labels.user)
|
||||
pr = PromptRepository(db, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_user_messages(
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
@@ -48,6 +48,6 @@ def query_user_messages(
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -6,27 +6,56 @@ import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
# The types of States a message tree can have.
|
||||
|
||||
class States(str, Enum):
|
||||
"""States of the Open-Assistant message tree state machine."""
|
||||
|
||||
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
|
||||
"""In this state the message tree consists only of a single inital prompt root node.
|
||||
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
|
||||
`aborted_low_grade`."""
|
||||
|
||||
class States(Enum):
|
||||
INITIAL = "initial"
|
||||
BREEDING_PHASE = "breeding_phase"
|
||||
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
|
||||
are handed out to check if the quality of the replies surpasses the minimum acceptable
|
||||
quality.
|
||||
When the required number of messages passing the initial labelling-quality check has been
|
||||
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
|
||||
are received the tree can also enter the `aborted_low_grade` state."""
|
||||
|
||||
RANKING_PHASE = "ranking_phase"
|
||||
"""The tree has been successfully populated with the desired number of messages. Ranking
|
||||
tasks are now handed out for all nodes with more than one child."""
|
||||
|
||||
READY_FOR_SCORING = "ready_for_scoring"
|
||||
CHILDREN_SCORED = "children_scored"
|
||||
FINAL = "final"
|
||||
"""Required ranking responses have been collected and the scoring algorithm can now
|
||||
compute the aggergated ranking scores that will appear in the dataset."""
|
||||
|
||||
READY_FOR_EXPORT = "ready_for_export"
|
||||
"""The Scoring algorithm computed rankings scores for all childern. The message tree can be
|
||||
exported as part of an Open-Assistant message tree dataset."""
|
||||
|
||||
SCORING_FAILED = "scoring_failed"
|
||||
"""An exception occured in the scoring algorithm."""
|
||||
|
||||
ABORTED_LOW_GRADE = "aborted_low_grade"
|
||||
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
|
||||
|
||||
HALTED_BY_MODERATOR = "halted_by_moderator"
|
||||
"""A moderator decided to manually halt the message tree construction process."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
States.INITIAL,
|
||||
States.INITIAL_PROMPT_REVIEW,
|
||||
States.BREEDING_PHASE,
|
||||
States.RANKING_PHASE,
|
||||
States.READY_FOR_SCORING,
|
||||
States.CHILDREN_SCORED,
|
||||
States.FINAL,
|
||||
States.READY_FOR_EXPORT,
|
||||
States.ABORTED_LOW_GRADE,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
|
||||
|
||||
|
||||
class MessageTreeState(SQLModel, table=True):
|
||||
__tablename__ = "message_tree_state"
|
||||
|
||||
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User] = None,
|
||||
user_repository: Optional[UserRepository] = None,
|
||||
task_repository: Optional[TaskRepository] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user = self.lookup_user(user)
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
)
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user is None:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def validate_frontend_message_id(self, message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
|
||||
@@ -113,20 +54,48 @@ class PromptRepository:
|
||||
)
|
||||
return message
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
self.validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
return task
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
@@ -174,7 +143,7 @@ class PromptRepository:
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
task = self.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(task_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
@@ -201,7 +170,7 @@ class PromptRepository:
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
# fetch task
|
||||
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
@@ -255,142 +224,6 @@ class PromptRepository:
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
@@ -515,28 +348,6 @@ class PromptRepository:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
@@ -728,24 +539,3 @@ class PromptRepository:
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def validate_frontend_message_id(message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
|
||||
class TaskRepository:
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User],
|
||||
user_repository: UserRepository,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user_repository = user_repository
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return task
|
||||
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
|
||||
|
||||
class UserRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user is None:
|
||||
if create_missing:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
Reference in New Issue
Block a user