diff --git a/backend/alembic/versions/2023_01_15_0002-7c98102efbca_change_user_stats_ranking_counts.py b/backend/alembic/versions/2023_01_15_0002-7c98102efbca_change_user_stats_ranking_counts.py new file mode 100644 index 00000000..163cfa8f --- /dev/null +++ b/backend/alembic/versions/2023_01_15_0002-7c98102efbca_change_user_stats_ranking_counts.py @@ -0,0 +1,83 @@ +"""change user_stats ranking counts + +Revision ID: 7c98102efbca +Revises: 619255ae9076 +Create Date: 2023-01-15 00:02:45.622986 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers, used by Alembic. +revision = "7c98102efbca" +down_revision = "619255ae9076" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_stats") + op.create_table( + "user_stats", + sa.Column("user_id", UUID(as_uuid=True), nullable=False), + sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("base_date", sa.DateTime(), nullable=True), + sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("leader_score", sa.Integer(), nullable=False), + sa.Column("prompts", sa.Integer(), nullable=False), + sa.Column("replies_assistant", sa.Integer(), nullable=False), + sa.Column("replies_prompter", sa.Integer(), nullable=False), + sa.Column("labels_simple", sa.Integer(), nullable=False), + sa.Column("labels_full", sa.Integer(), nullable=False), + sa.Column("rankings_total", sa.Integer(), nullable=False), + sa.Column("rankings_good", sa.Integer(), nullable=False), + sa.Column("accepted_prompts", sa.Integer(), nullable=False), + sa.Column("accepted_replies_assistant", sa.Integer(), nullable=False), + sa.Column("accepted_replies_prompter", sa.Integer(), nullable=False), + sa.Column("reply_ranked_1", sa.Integer(), nullable=False), + sa.Column("reply_ranked_2", 
sa.Integer(), nullable=False), + sa.Column("reply_ranked_3", sa.Integer(), nullable=False), + sa.Column("streak_last_day_date", sa.DateTime(), nullable=True), + sa.Column("streak_days", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("user_id", "time_frame"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "user_stats", + sa.Column("reply_prompter_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False), + ) + op.add_column( + "user_stats", + sa.Column("reply_assistant_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False), + ) + op.add_column( + "user_stats", + sa.Column("reply_assistant_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False), + ) + op.add_column( + "user_stats", + sa.Column("reply_prompter_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False), + ) + op.add_column( + "user_stats", + sa.Column("reply_prompter_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False), + ) + op.add_column( + "user_stats", + sa.Column("reply_assistant_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False), + ) + op.drop_column("user_stats", "reply_ranked_3") + op.drop_column("user_stats", "reply_ranked_2") + op.drop_column("user_stats", "reply_ranked_1") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_01_15_1139-423557e869e4_add_indices_for_created_date.py b/backend/alembic/versions/2023_01_15_1139-423557e869e4_add_indices_for_created_date.py new file mode 100644 index 00000000..ae03b3df --- /dev/null +++ b/backend/alembic/versions/2023_01_15_1139-423557e869e4_add_indices_for_created_date.py @@ -0,0 +1,30 @@ +"""add indices for created_date + +Revision ID: 423557e869e4 +Revises: 7c98102efbca +Create Date: 2023-01-15 
11:39:10.407859 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "423557e869e4" +down_revision = "7c98102efbca" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f("ix_message_created_date"), "message", ["created_date"], unique=False) + op.create_index(op.f("ix_message_reaction_created_date"), "message_reaction", ["created_date"], unique=False) + op.create_index(op.f("ix_text_labels_created_date"), "text_labels", ["created_date"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_text_labels_created_date"), table_name="text_labels") + op.drop_index(op.f("ix_message_reaction_created_date"), table_name="message_reaction") + op.drop_index(op.f("ix_message_created_date"), table_name="message") + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index cf0b1f76..fb7aa26b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -9,6 +9,7 @@ import alembic.config import fastapi import redis.asyncio as redis from fastapi_limiter import FastAPILimiter +from fastapi_utils.tasks import repeat_every from loguru import logger from oasst_backend.api.deps import get_dummy_api_client from oasst_backend.api.v1.api import api_router @@ -18,6 +19,7 @@ from oasst_backend.database import engine from oasst_backend.models import message_tree_state from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository from oasst_backend.tree_manager import TreeManager +from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -195,6 +197,50 @@ def ensure_tree_states(): 
def _run_leader_board_update(time_frame: UserStatsTimeFrame, label: str) -> None:
    """Recompute the user stats of one time frame inside a fresh DB session.

    Args:
        time_frame: The leaderboard time frame to refresh.
        label: Human-readable name of the time frame, used only in the error log.

    Any exception is logged (with traceback) instead of being propagated to
    the periodic scheduler that invoked the task.
    """
    try:
        with Session(engine) as session:
            UserStatsRepository(session).update_stats(time_frame=time_frame)
    except Exception:
        logger.exception(f"Error during leaderboard update ({label})")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
def update_leader_board_day() -> None:
    """Refresh the daily leaderboard every USER_STATS_INTERVAL_DAY minutes."""
    _run_leader_board_update(UserStatsTimeFrame.day, "daily")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
def update_leader_board_week() -> None:
    """Refresh the weekly leaderboard every USER_STATS_INTERVAL_WEEK minutes."""
    _run_leader_board_update(UserStatsTimeFrame.week, "weekly")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
def update_leader_board_month() -> None:
    """Refresh the monthly leaderboard every USER_STATS_INTERVAL_MONTH minutes."""
    _run_leader_board_update(UserStatsTimeFrame.month, "monthly")


@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
def update_leader_board_total() -> None:
    """Refresh the all-time leaderboard every USER_STATS_INTERVAL_TOTAL minutes."""
    _run_leader_board_update(UserStatsTimeFrame.total, "total")
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"]) api_router.include_router(stats.router, prefix="/stats", tags=["stats"]) -api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"]) +api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"]) api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"]) diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py index 46aea637..0a6e5660 100644 --- a/backend/oasst_backend/api/v1/leaderboards.py +++ b/backend/oasst_backend/api/v1/leaderboards.py @@ -1,26 +1,21 @@ -from fastapi import APIRouter, Depends +from typing import Optional + +from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.models import ApiClient -from oasst_backend.user_repository import UserRepository +from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame from oasst_shared.schemas.protocol import LeaderboardStats from sqlmodel import Session router = APIRouter() -@router.get("/create/assistant") -def get_assistant_leaderboard( +@router.get("/{time_frame}") +def get_leaderboard_day( + time_frame: UserStatsTimeFrame, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), - api_client: ApiClient = Depends(deps.get_trusted_api_client), ) -> LeaderboardStats: - ur = UserRepository(db, api_client) - return ur.get_user_leaderboard(role="assistant") - - -@router.get("/create/prompter") -def get_prompter_leaderboard( - db: Session = Depends(deps.get_db), - api_client: ApiClient = Depends(deps.get_trusted_api_client), -) -> LeaderboardStats: - ur = UserRepository(db, api_client) - return ur.get_user_leaderboard(role="prompter") + usr = UserStatsRepository(db) + return 
usr.get_leader_board(time_frame, limit=max_count) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 53277d2b..1b182a87 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -109,6 +109,22 @@ class Settings(BaseSettings): tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration() + USER_STATS_INTERVAL_DAY: int = 15 # minutes + USER_STATS_INTERVAL_WEEK: int = 60 # minutes + USER_STATS_INTERVAL_MONTH: int = 120 # minutes + USER_STATS_INTERVAL_TOTAL: int = 240 # minutes + + @validator( + "USER_STATS_INTERVAL_DAY", + "USER_STATS_INTERVAL_WEEK", + "USER_STATS_INTERVAL_MONTH", + "USER_STATS_INTERVAL_TOTAL", + ) + def validate_user_stats_intervals(cls, v: int): + if v < 1: + raise ValueError(v) + return v + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 9dc052d7..2b30b475 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -8,12 +8,13 @@ from .message_tree_state import MessageTreeState from .task import Task from .text_labels import TextLabels from .user import User -from .user_stats import UserStats +from .user_stats import UserStats, UserStatsTimeFrame __all__ = [ "ApiClient", "User", "UserStats", + "UserStatsTimeFrame", "Message", "MessageEmbedding", "MessageReaction", diff --git a/backend/oasst_backend/models/db_payload.py b/backend/oasst_backend/models/db_payload.py index 590e9f5b..ddaf2391 100644 --- a/backend/oasst_backend/models/db_payload.py +++ b/backend/oasst_backend/models/db_payload.py @@ -65,12 +65,16 @@ class RankingReactionPayload(ReactionPayload): type: Literal["message_ranking"] = "message_ranking" ranking: list[int] ranked_message_ids: list[UUID] + ranking_parent_id: Optional[UUID] + message_tree_id: Optional[UUID] @payload_type class RankConversationRepliesPayload(TaskPayload): conversation: 
protocol_schema.Conversation # the conversation so far reply_messages: list[protocol_schema.ConversationMessage] + ranking_parent_id: Optional[UUID] + message_tree_id: Optional[UUID] @payload_type @@ -104,6 +108,7 @@ class LabelInitialPromptPayload(TaskPayload): prompt: str valid_labels: list[str] mandatory_labels: Optional[list[str]] + mode: Optional[protocol_schema.LabelTaskMode] @payload_type @@ -115,6 +120,7 @@ class LabelConversationReplyPayload(TaskPayload): reply: str valid_labels: list[str] mandatory_labels: Optional[list[str]] + mode: Optional[protocol_schema.LabelTaskMode] @payload_type diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 9583510f..7c8b9f13 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -30,7 +30,7 @@ class Message(SQLModel, table=True): api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") frontend_message_id: str = Field(max_length=200, nullable=False) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) + sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True) ) payload_type: str = Field(nullable=False, max_length=200) payload: Optional[PayloadContainer] = Field( diff --git a/backend/oasst_backend/models/message_reaction.py b/backend/oasst_backend/models/message_reaction.py index 3aaa774c..74e21a61 100644 --- a/backend/oasst_backend/models/message_reaction.py +++ b/backend/oasst_backend/models/message_reaction.py @@ -19,7 +19,7 @@ class MessageReaction(SQLModel, table=True): sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True) ) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) + sa_column=sa.Column(sa.DateTime(), nullable=False, 
server_default=sa.func.current_timestamp(), index=True) ) payload_type: str = Field(nullable=False, max_length=200) payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py index d923be97..5d0d7e73 100644 --- a/backend/oasst_backend/models/task.py +++ b/backend/oasst_backend/models/task.py @@ -4,6 +4,7 @@ from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg +from oasst_shared.utils import utcnow from sqlalchemy import false from sqlmodel import Field, SQLModel @@ -35,4 +36,4 @@ class Task(SQLModel, table=True): @property def expired(self) -> bool: - return self.expiry_date is not None and datetime.utcnow() > self.expiry_date + return self.expiry_date is not None and utcnow() > self.expiry_date diff --git a/backend/oasst_backend/models/text_labels.py b/backend/oasst_backend/models/text_labels.py index e6878a87..34831d6b 100644 --- a/backend/oasst_backend/models/text_labels.py +++ b/backend/oasst_backend/models/text_labels.py @@ -17,7 +17,7 @@ class TextLabels(SQLModel, table=True): ) user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False)) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), + sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True), ) api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") text: str = Field(nullable=False, max_length=2**16) diff --git a/backend/oasst_backend/models/user_stats.py b/backend/oasst_backend/models/user_stats.py index 0bdaa3f4..393e3851 100644 --- a/backend/oasst_backend/models/user_stats.py +++ b/backend/oasst_backend/models/user_stats.py @@ -22,6 +22,7 @@ class UserStats(SQLModel, table=True): sa_column=sa.Column(pg.UUID(as_uuid=True), 
def compute_leader_score(self) -> int:
    """Return the leaderboard score as a weighted sum of the activity counters.

    Each counter of this stats row is multiplied by its weight and the
    products are added up; the stored ``leader_score`` field is not modified
    here.
    """
    # (counter attribute, weight) pairs
    weighted_counters = (
        ("prompts", 1),
        ("replies_assistant", 4),
        ("replies_prompter", 1),
        ("labels_simple", 1),
        ("labels_full", 2),
        ("rankings_total", 1),
        ("rankings_good", 1),
        ("accepted_prompts", 1),
        ("accepted_replies_assistant", 4),
        ("accepted_replies_prompter", 1),
        ("reply_ranked_1", 9),
        ("reply_ranked_2", 3),
        ("reply_ranked_3", 1),
    )
    return sum(getattr(self, counter) * weight for counter, weight in weighted_counters)
a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index de7eb28a..acf48182 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -67,17 +67,30 @@ class TaskRepository: case protocol_schema.RankPrompterRepliesTask: payload = db_payload.RankPrompterRepliesPayload( - type=task.type, conversation=task.conversation, reply_messages=task.reply_messages + type=task.type, + conversation=task.conversation, + reply_messages=task.reply_messages, + ranking_parent_id=task.ranking_parent_id, + message_tree_id=task.message_tree_id, ) case protocol_schema.RankAssistantRepliesTask: payload = db_payload.RankAssistantRepliesPayload( - type=task.type, conversation=task.conversation, reply_messages=task.reply_messages + type=task.type, + conversation=task.conversation, + reply_messages=task.reply_messages, + ranking_parent_id=task.ranking_parent_id, + message_tree_id=task.message_tree_id, ) case protocol_schema.LabelInitialPromptTask: payload = db_payload.LabelInitialPromptPayload( - type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels + type=task.type, + message_id=task.message_id, + prompt=task.prompt, + valid_labels=task.valid_labels, + mandatory_labels=task.mandatory_labels, + mode=task.mode, ) case protocol_schema.LabelPrompterReplyTask: @@ -88,6 +101,7 @@ class TaskRepository: reply=task.reply, valid_labels=task.valid_labels, mandatory_labels=task.mandatory_labels, + mode=task.mode, ) case protocol_schema.LabelAssistantReplyTask: @@ -98,6 +112,7 @@ class TaskRepository: reply=task.reply, valid_labels=task.valid_labels, mandatory_labels=task.mandatory_labels, + mode=task.mode, ) case _: diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index fe45ca0e..a9a282e2 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -207,6 +207,9 @@ class TreeManager: ranking_parent_id = 
random.choice(incomplete_rankings).parent_id messages = self.pr.fetch_message_conversation(ranking_parent_id) + assert len(messages) > 1 and messages[-1].id == ranking_parent_id + ranking_parent = messages[-1] + assert not ranking_parent.deleted and ranking_parent.review_result conversation = prepare_conversation(messages) replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) @@ -218,12 +221,20 @@ class TreeManager: if messages[-1].role == "assistant": logger.info("Generating a RankPrompterRepliesTask.") task = protocol_schema.RankPrompterRepliesTask( - conversation=conversation, replies=replies, reply_messages=reply_messages + conversation=conversation, + replies=replies, + reply_messages=reply_messages, + ranking_parent_id=ranking_parent.id, + message_tree_id=ranking_parent.message_tree_id, ) else: logger.info("Generating a RankAssistantRepliesTask.") task = protocol_schema.RankAssistantRepliesTask( - conversation=conversation, replies=replies, reply_messages=reply_messages + conversation=conversation, + replies=replies, + reply_messages=reply_messages, + ranking_parent_id=ranking_parent.id, + message_tree_id=ranking_parent.message_tree_id, ) parent_message_id = ranking_parent_id diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 26de963f..b86778aa 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -1,11 +1,10 @@ from typing import Optional from uuid import UUID -from oasst_backend.models import ApiClient, Message, User +from oasst_backend.models import ApiClient, User from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from oasst_shared.schemas.protocol import LeaderboardStats -from sqlmodel import Session, func +from sqlmodel import Session from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -137,27 +136,6 @@ class 
def get_leader_board(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
    """
    Get leaderboard stats for the specified time frame

    Args:
        time_frame: Which pre-computed stats rows to read.
        limit: Maximum number of leaderboard entries to return.

    Returns:
        The users of the time frame ordered by descending ``leader_score``;
        ``user_rank`` is the 1-based position in that ordering.
    """

    qry = (
        self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
        .join(UserStats, User.id == UserStats.user_id)
        .filter(UserStats.time_frame == time_frame.value)
        .order_by(UserStats.leader_score.desc())
        .limit(limit)
    )

    def create_user_score(user_rank: int, r):
        # start from the stats columns, then overlay the user columns selected next to them
        d = r["UserStats"].dict()
        for k in ["user_id", "username", "auth_method", "display_name"]:
            d[k] = r[k]
        return UserScore(user_rank=user_rank, **d)

    # ranks are 1-based: the best scoring user is rank 1, not rank 0
    leaderboard = [create_user_score(i, r) for i, r in enumerate(self.session.exec(qry), start=1)]
    return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
def _update_stats_internal(self, time_frame_key: str, base_date: Optional[datetime] = None):
    """Recompute and replace all ``UserStats`` rows of one time frame.

    Runs the per-user aggregate queries of this repository, accumulates their
    counts into one ``UserStats`` object per user, deletes the existing rows
    for ``time_frame_key`` and adds the fresh ones. The session is flushed
    but not committed here.

    Args:
        time_frame_key: Value stored in ``UserStats.time_frame`` for the rows.
        base_date: Passed as ``reference_time`` to every aggregate query and
            stored on each new row; ``None`` applies no date filter.
    """
    # gather user data

    stats_by_user: dict[UUID, UserStats] = dict()
    now = utcnow()

    def get_stats(user_id: UUID) -> UserStats:
        # lazily create the stats object the first time a user shows up in any query
        us = stats_by_user.get(user_id)
        if us is None:
            us = UserStats(user_id=user_id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
            stats_by_user[user_id] = us
        return us

    # total prompts
    qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=False)
    for r in qry:
        uid, count = r
        get_stats(uid).prompts = count

    # accepted prompts
    qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=True)
    for r in qry:
        uid, count = r
        get_stats(uid).accepted_prompts = count

    # total replies
    qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=False)
    for r in qry:
        uid, role, count = r
        s = get_stats(uid)
        if role == "assistant":
            s.replies_assistant += count
        elif role == "prompter":
            s.replies_prompter += count

    # accepted replies
    qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=True)
    for r in qry:
        uid, role, count = r
        s = get_stats(uid)
        if role == "assistant":
            s.accepted_replies_assistant += count
        elif role == "prompter":
            s.accepted_replies_prompter += count

    # simple and full labels (assistant-reply label tasks set the counters ...)
    qry = self.query_labels_by_mode_per_user(
        payload_type=LabelAssistantReplyPayload.__name__, reference_time=base_date
    )
    for r in qry:
        uid, mode, count = r
        s = get_stats(uid)
        if mode == "simple":
            s.labels_simple = count
        elif mode == "full":
            s.labels_full = count

    # ... and prompter-reply label tasks are added on top
    qry = self.query_labels_by_mode_per_user(
        payload_type=LabelPrompterReplyPayload.__name__, reference_time=base_date
    )
    for r in qry:
        uid, mode, count = r
        s = get_stats(uid)
        if mode == "simple":
            s.labels_simple += count
        elif mode == "full":
            s.labels_full += count

    # rankings submitted by the user
    qry = self.query_rankings_per_user(reference_time=base_date)
    for r in qry:
        uid, count = r
        get_stats(uid).rankings_total = count

    # how often a user's reply was placed at position 0, 1, 2 of a ranking;
    # the position index must follow the field, otherwise every field would
    # receive the first-place counts
    rank_field_names = ["reply_ranked_1", "reply_ranked_2", "reply_ranked_3"]
    for i, fn in enumerate(rank_field_names):
        qry = self.query_ranking_result_users(reference_time=base_date, rank=i)
        for r in qry:
            uid, count = r
            setattr(get_stats(uid), fn, count)

    # delete all existing stats for time frame
    d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
    self.session.execute(d)

    # compute magic leader score
    for v in stats_by_user.values():
        v.leader_score = v.compute_leader_score()

    # insert user objects
    self.session.add_all(stats_by_user.values())
    self.session.flush()
UserStatsTimeFrame): + now = utcnow() + match time_frame: + case UserStatsTimeFrame.day: + r = now - timedelta(days=1) + self.update_stats_time_frame(time_frame.value, r) + + case UserStatsTimeFrame.week: + r = now.date() - timedelta(days=7) + r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo) + self.update_stats_time_frame(time_frame.value, r) + + case UserStatsTimeFrame.month: + r = now.date() - timedelta(days=30) + r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo) + self.update_stats_time_frame(time_frame.value, r) + + case UserStatsTimeFrame.total: + self.update_stats_time_frame(time_frame.value, None) + + @log_timing(level="INFO") + def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]): + for t in time_frames: + self.update_stats(time_frame=t) + + @log_timing(level="INFO") + def update_all_time_frames(self): + self.update_multiple_time_frames(list(UserStatsTimeFrame)) + + +if __name__ == "__main__": + from oasst_backend.api.deps import get_dummy_api_client + from oasst_backend.database import engine + + with Session(engine) as session: + api_client = get_dummy_api_client(session) + usr = UserStatsRepository(session) + usr.update_all_time_frames() + usr.get_leader_board(UserStatsTimeFrame.total) diff --git a/backend/requirements.txt b/backend/requirements.txt index fedf8ee3..0f91315e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,6 +1,7 @@ alembic==1.8.1 fastapi==0.88.0 fastapi-limiter==0.1.5 +fastapi-utils==0.2.1 loguru==0.6.0 numpy==1.22.4 psycopg2-binary==2.9.5 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 839ac944..23b19e71 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -169,6 +169,8 @@ class RankConversationRepliesTask(Task): conversation: Conversation # the conversation so far replies: list[str] # deprecated, use reply_messages reply_messages: 
list[ConversationMessage] + message_tree_id: UUID + ranking_parent_id: UUID class RankPrompterRepliesTask(RankConversationRepliesTask): @@ -356,14 +358,40 @@ class SystemStats(BaseModel): class UserScore(BaseModel): - ranking: int + user_rank: int user_id: UUID username: str + auth_method: str display_name: str - score: int + + leader_score: int + + base_date: Optional[datetime] + modified_date: datetime + + prompts: int + replies_assistant: int + replies_prompter: int + labels_simple: int + labels_full: int + rankings_total: int + rankings_good: int + + accepted_prompts: int + accepted_replies_assistant: int + accepted_replies_prompter: int + + reply_ranked_1: int + reply_ranked_2: int + reply_ranked_3: int + + # only used for time frame "total" + streak_last_day_date: Optional[datetime] + streak_days: Optional[int] class LeaderboardStats(BaseModel): + time_frame: str leaderboard: List[UserScore] diff --git a/oasst-shared/oasst_shared/utils.py b/oasst-shared/oasst_shared/utils.py index b99bb7ed..90ba8c8f 100644 --- a/oasst-shared/oasst_shared/utils.py +++ b/oasst-shared/oasst_shared/utils.py @@ -1,6 +1,32 @@ +import time from datetime import datetime, timezone +from functools import wraps + +from loguru import logger def utcnow() -> datetime: """Return the current utc date and time with tzinfo set to UTC.""" return datetime.now(timezone.utc) + + +def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None: + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + elapsed = end - start + if log_kwargs: + kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()]) + logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s") + else: + logger.log(level, f"Function '{func.__name__}' executed in {elapsed:f} s") + return result + + return wrapped + + if func and callable(func): + return decorator(func) + return decorator