Add leaderboard stats, periodic updates via fastapi-utils (#724)

* add leaderboard stats, periodic update via fastapi-utils

* count label tasks for assistant and prompter replies

* Daily stats update every 15 mins, simplify leaderboard endpoint

* add indices for some created_date columns

* make user stats update intervals configurable

* make sure intervals are positive
This commit is contained in:
Andreas Köpf
2023-01-15 12:04:19 +01:00
committed by GitHub
parent e01f2eb4ab
commit b5bb5bb7c0
21 changed files with 555 additions and 61 deletions
@@ -0,0 +1,83 @@
"""change user_stats ranking counts
Revision ID: 7c98102efbca
Revises: 619255ae9076
Create Date: 2023-01-15 00:02:45.622986
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "7c98102efbca"
down_revision = "619255ae9076"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user_stats")
op.create_table(
"user_stats",
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("base_date", sa.DateTime(), nullable=True),
sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("leader_score", sa.Integer(), nullable=False),
sa.Column("prompts", sa.Integer(), nullable=False),
sa.Column("replies_assistant", sa.Integer(), nullable=False),
sa.Column("replies_prompter", sa.Integer(), nullable=False),
sa.Column("labels_simple", sa.Integer(), nullable=False),
sa.Column("labels_full", sa.Integer(), nullable=False),
sa.Column("rankings_total", sa.Integer(), nullable=False),
sa.Column("rankings_good", sa.Integer(), nullable=False),
sa.Column("accepted_prompts", sa.Integer(), nullable=False),
sa.Column("accepted_replies_assistant", sa.Integer(), nullable=False),
sa.Column("accepted_replies_prompter", sa.Integer(), nullable=False),
sa.Column("reply_ranked_1", sa.Integer(), nullable=False),
sa.Column("reply_ranked_2", sa.Integer(), nullable=False),
sa.Column("reply_ranked_3", sa.Integer(), nullable=False),
sa.Column("streak_last_day_date", sa.DateTime(), nullable=True),
sa.Column("streak_days", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("user_id", "time_frame"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"user_stats",
sa.Column("reply_prompter_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
)
op.add_column(
"user_stats",
sa.Column("reply_assistant_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
)
op.add_column(
"user_stats",
sa.Column("reply_assistant_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
)
op.add_column(
"user_stats",
sa.Column("reply_prompter_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
)
op.add_column(
"user_stats",
sa.Column("reply_prompter_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
)
op.add_column(
"user_stats",
sa.Column("reply_assistant_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
)
op.drop_column("user_stats", "reply_ranked_3")
op.drop_column("user_stats", "reply_ranked_2")
op.drop_column("user_stats", "reply_ranked_1")
# ### end Alembic commands ###
@@ -0,0 +1,30 @@
"""add indices for created_date
Revision ID: 423557e869e4
Revises: 7c98102efbca
Create Date: 2023-01-15 11:39:10.407859
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "423557e869e4"
down_revision = "7c98102efbca"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(op.f("ix_message_created_date"), "message", ["created_date"], unique=False)
op.create_index(op.f("ix_message_reaction_created_date"), "message_reaction", ["created_date"], unique=False)
op.create_index(op.f("ix_text_labels_created_date"), "text_labels", ["created_date"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_text_labels_created_date"), table_name="text_labels")
op.drop_index(op.f("ix_message_reaction_created_date"), table_name="message_reaction")
op.drop_index(op.f("ix_message_created_date"), table_name="message")
# ### end Alembic commands ###
+46
View File
@@ -9,6 +9,7 @@ import alembic.config
import fastapi
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from fastapi_utils.tasks import repeat_every
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
@@ -18,6 +19,7 @@ from oasst_backend.database import engine
from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -195,6 +197,50 @@ def ensure_tree_states():
logger.exception("TreeManager.ensure_tree_states() failed.")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
def update_leader_board_day() -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.day)
except Exception:
logger.exception("Error during leaderboard update (daily)")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
def update_leader_board_week() -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.week)
except Exception:
logger.exception("Error during user states update (weekly)")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
def update_leader_board_month() -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.month)
except Exception:
logger.exception("Error during user states update (monthly)")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
def update_leader_board_total() -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.total)
except Exception:
logger.exception("Error during user states update (total)")
app.include_router(api_router, prefix=settings.API_V1_STR)
+1 -1
View File
@@ -19,5 +19,5 @@ api_router.include_router(frontend_messages.router, prefix="/frontend_messages",
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
+11 -16
View File
@@ -1,26 +1,21 @@
from fastapi import APIRouter, Depends
from typing import Optional
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.user_repository import UserRepository
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
router = APIRouter()
@router.get("/create/assistant")
def get_assistant_leaderboard(
@router.get("/{time_frame}")
def get_leaderboard_day(
time_frame: UserStatsTimeFrame,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="prompter")
usr = UserStatsRepository(db)
return usr.get_leader_board(time_frame, limit=max_count)
+16
View File
@@ -109,6 +109,22 @@ class Settings(BaseSettings):
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
USER_STATS_INTERVAL_DAY: int = 15 # minutes
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
@validator(
"USER_STATS_INTERVAL_DAY",
"USER_STATS_INTERVAL_WEEK",
"USER_STATS_INTERVAL_MONTH",
"USER_STATS_INTERVAL_TOTAL",
)
def validate_user_stats_intervals(cls, v: int):
if v < 1:
raise ValueError(v)
return v
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
+2 -1
View File
@@ -8,12 +8,13 @@ from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
from .user import User
from .user_stats import UserStats
from .user_stats import UserStats, UserStatsTimeFrame
__all__ = [
"ApiClient",
"User",
"UserStats",
"UserStatsTimeFrame",
"Message",
"MessageEmbedding",
"MessageReaction",
@@ -65,12 +65,16 @@ class RankingReactionPayload(ReactionPayload):
type: Literal["message_ranking"] = "message_ranking"
ranking: list[int]
ranked_message_ids: list[UUID]
ranking_parent_id: Optional[UUID]
message_tree_id: Optional[UUID]
@payload_type
class RankConversationRepliesPayload(TaskPayload):
conversation: protocol_schema.Conversation # the conversation so far
reply_messages: list[protocol_schema.ConversationMessage]
ranking_parent_id: Optional[UUID]
message_tree_id: Optional[UUID]
@payload_type
@@ -104,6 +108,7 @@ class LabelInitialPromptPayload(TaskPayload):
prompt: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[protocol_schema.LabelTaskMode]
@payload_type
@@ -115,6 +120,7 @@ class LabelConversationReplyPayload(TaskPayload):
reply: str
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[protocol_schema.LabelTaskMode]
@payload_type
+1 -1
View File
@@ -30,7 +30,7 @@ class Message(SQLModel, table=True):
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
)
payload_type: str = Field(nullable=False, max_length=200)
payload: Optional[PayloadContainer] = Field(
@@ -19,7 +19,7 @@ class MessageReaction(SQLModel, table=True):
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
+2 -1
View File
@@ -4,6 +4,7 @@ from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_shared.utils import utcnow
from sqlalchemy import false
from sqlmodel import Field, SQLModel
@@ -35,4 +36,4 @@ class Task(SQLModel, table=True):
@property
def expired(self) -> bool:
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
return self.expiry_date is not None and utcnow() > self.expiry_date
+1 -1
View File
@@ -17,7 +17,7 @@ class TextLabels(SQLModel, table=True):
)
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
+21 -7
View File
@@ -22,6 +22,7 @@ class UserStats(SQLModel, table=True):
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
)
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
leader_score: int = 0
modified_date: Optional[datetime] = Field(
@@ -40,14 +41,27 @@ class UserStats(SQLModel, table=True):
accepted_replies_assistant: int = 0
accepted_replies_prompter: int = 0
reply_assistant_ranked_1: int = 0
reply_assistant_ranked_2: int = 0
reply_assistant_ranked_3: int = 0
reply_prompter_ranked_1: int = 0
reply_prompter_ranked_2: int = 0
reply_prompter_ranked_3: int = 0
reply_ranked_1: int = 0
reply_ranked_2: int = 0
reply_ranked_3: int = 0
# only used for time span "total"
streak_last_day_date: Optional[datetime] = Field(nullable=True)
streak_days: Optional[int] = Field(nullable=True)
def compute_leader_score(self) -> int:
return (
self.prompts
+ self.replies_assistant * 4
+ self.replies_prompter
+ self.labels_simple
+ self.labels_full * 2
+ self.rankings_total
+ self.rankings_good
+ self.accepted_prompts
+ self.accepted_replies_assistant * 4
+ self.accepted_replies_prompter
+ self.reply_ranked_1 * 9
+ self.reply_ranked_2 * 3
+ self.reply_ranked_3
)
+4 -1
View File
@@ -260,7 +260,10 @@ class PromptRepository:
self.db.add(message)
reaction_payload = db_payload.RankingReactionPayload(
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
ranking=ranking.ranking,
ranked_message_ids=ranked_message_ids,
ranking_parent_id=task_payload.ranking_parent_id,
message_tree_id=task_payload.message_tree_id,
)
reaction = self.insert_reaction(task.id, reaction_payload)
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
+18 -3
View File
@@ -67,17 +67,30 @@ class TaskRepository:
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
type=task.type,
conversation=task.conversation,
reply_messages=task.reply_messages,
ranking_parent_id=task.ranking_parent_id,
message_tree_id=task.message_tree_id,
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
type=task.type,
conversation=task.conversation,
reply_messages=task.reply_messages,
ranking_parent_id=task.ranking_parent_id,
message_tree_id=task.message_tree_id,
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
type=task.type,
message_id=task.message_id,
prompt=task.prompt,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
)
case protocol_schema.LabelPrompterReplyTask:
@@ -88,6 +101,7 @@ class TaskRepository:
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
)
case protocol_schema.LabelAssistantReplyTask:
@@ -98,6 +112,7 @@ class TaskRepository:
reply=task.reply,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
)
case _:
+13 -2
View File
@@ -207,6 +207,9 @@ class TreeManager:
ranking_parent_id = random.choice(incomplete_rankings).parent_id
messages = self.pr.fetch_message_conversation(ranking_parent_id)
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
ranking_parent = messages[-1]
assert not ranking_parent.deleted and ranking_parent.review_result
conversation = prepare_conversation(messages)
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
@@ -218,12 +221,20 @@ class TreeManager:
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=conversation, replies=replies, reply_messages=reply_messages
conversation=conversation,
replies=replies,
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
)
parent_message_id = ranking_parent_id
+2 -24
View File
@@ -1,11 +1,10 @@
from typing import Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Message, User
from oasst_backend.models import ApiClient, User
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -137,27 +136,6 @@ class UserRepository:
self.db.commit()
return user
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)
def query_users(
self,
api_client_id: Optional[UUID] = None,
@@ -0,0 +1,240 @@
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
LabelPrompterReplyPayload,
RankingReactionPayload,
)
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy.dialects import postgresql
from sqlmodel import Session, delete, func
class UserStatsRepository:
def __init__(self, session: Session):
self.session = session
def get_leader_board(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
"""
Get leaderboard stats for the specified time frame
"""
qry = (
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
.join(UserStats, User.id == UserStats.user_id)
.filter(UserStats.time_frame == time_frame.value)
.order_by(UserStats.leader_score.desc())
.limit(limit)
)
def create_user_score(user_rank: int, r):
d = r["UserStats"].dict()
for k in ["user_id", "username", "auth_method", "display_name"]:
d[k] = r[k]
return UserScore(user_rank=user_rank, **d)
leaderboard = [create_user_score(i, r) for i, r in enumerate(self.session.exec(qry))]
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
def query_total_prompts_per_user(
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
):
qry = self.session.query(Message.user_id, func.count()).filter(
Message.deleted == sa.false(), Message.parent_id.is_(None)
)
if reference_time:
qry = qry.filter(Message.created_date >= reference_time)
if only_reviewed:
qry = qry.filter(Message.review_result == sa.true())
qry = qry.group_by(Message.user_id)
return qry
def query_replies_by_role_per_user(
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
) -> list:
qry = self.session.query(Message.user_id, Message.role, func.count()).filter(
Message.deleted == sa.false(), Message.parent_id.is_not(None)
)
if reference_time:
qry = qry.filter(Message.created_date >= reference_time)
if only_reviewed:
qry = qry.filter(Message.review_result == sa.true())
qry = qry.group_by(Message.user_id, Message.role)
return qry
def query_labels_by_mode_per_user(
self, payload_type: str = LabelAssistantReplyPayload.__name__, reference_time: Optional[datetime] = None
):
qry = self.session.query(Task.user_id, Task.payload["payload", "mode"].astext, func.count()).filter(
Task.done == sa.true(), Task.payload_type == payload_type
)
if reference_time:
qry = qry.filter(Task.created_date >= reference_time)
qry = qry.group_by(Task.user_id, Task.payload["payload", "mode"].astext)
return qry
def query_rankings_per_user(self, reference_time: Optional[datetime] = None):
qry = self.session.query(MessageReaction.user_id, func.count()).filter(
MessageReaction.payload_type == RankingReactionPayload.__name__
)
if reference_time:
qry = qry.filter(MessageReaction.created_date >= reference_time)
qry = qry.group_by(MessageReaction.user_id)
return qry
def query_ranking_result_users(self, rank: int = 0, reference_time: Optional[datetime] = None):
ranked_message_id = MessageReaction.payload["payload", "ranked_message_ids", rank].astext.cast(
postgresql.UUID(as_uuid=True)
)
qry = (
self.session.query(Message.user_id, func.count())
.select_from(MessageReaction)
.join(Message, ranked_message_id == Message.id)
.filter(MessageReaction.payload_type == RankingReactionPayload.__name__)
)
if reference_time:
qry = qry.filter(MessageReaction.created_date >= reference_time)
qry = qry.group_by(Message.user_id)
return qry
def _update_stats_internal(self, time_frame_key: str, base_date: Optional[datetime] = None):
# gather user data
stats_by_user: dict[UUID, UserStats] = dict()
now = utcnow()
def get_stats(id: UUID) -> UserStats:
us = stats_by_user.get(id)
if not us:
us = UserStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
stats_by_user[id] = us
return us
# total prompts
qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=False)
for r in qry:
uid, count = r
get_stats(uid).prompts = count
# accepted prompts
qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=True)
for r in qry:
uid, count = r
get_stats(uid).accepted_prompts = count
# total replies
qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=False)
for r in qry:
uid, role, count = r
s = get_stats(uid)
if role == "assistant":
s.replies_assistant += count
elif role == "prompter":
s.replies_prompter += count
# accepted replies
qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=True)
for r in qry:
uid, role, count = r
s = get_stats(uid)
if role == "assistant":
s.accepted_replies_assistant += count
elif role == "prompter":
s.accepted_replies_prompter += count
# simple and full labels
qry = self.query_labels_by_mode_per_user(
payload_type=LabelAssistantReplyPayload.__name__, reference_time=base_date
)
for r in qry:
uid, mode, count = r
s = get_stats(uid)
if mode == "simple":
s.labels_simple = count
elif mode == "full":
s.labels_full = count
qry = self.query_labels_by_mode_per_user(
payload_type=LabelPrompterReplyPayload.__name__, reference_time=base_date
)
for r in qry:
uid, mode, count = r
s = get_stats(uid)
if mode == "simple":
s.labels_simple += count
elif mode == "full":
s.labels_full += count
qry = self.query_rankings_per_user(reference_time=base_date)
for r in qry:
uid, count = r
get_stats(uid).rankings_total = count
rank_field_names = ["reply_ranked_1", "reply_ranked_2", "reply_ranked_3"]
for i, fn in enumerate(rank_field_names):
qry = self.query_ranking_result_users(reference_time=base_date, rank=0)
for r in qry:
uid, count = r
setattr(get_stats(uid), fn, count)
# delete all existing stast for time frame
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
self.session.execute(d)
# compute magic leader score
for v in stats_by_user.values():
v.leader_score = v.compute_leader_score()
# insert user objects
self.session.add_all(stats_by_user.values())
self.session.flush()
def update_stats_time_frame(self, time_frame_key: str, reference_time: Optional[datetime] = None):
self._update_stats_internal(time_frame_key, reference_time)
self.session.commit()
@log_timing(log_kwargs=True, level="INFO")
def update_stats(self, *, time_frame: UserStatsTimeFrame):
now = utcnow()
match time_frame:
case UserStatsTimeFrame.day:
r = now - timedelta(days=1)
self.update_stats_time_frame(time_frame.value, r)
case UserStatsTimeFrame.week:
r = now.date() - timedelta(days=7)
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
self.update_stats_time_frame(time_frame.value, r)
case UserStatsTimeFrame.month:
r = now.date() - timedelta(days=30)
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
self.update_stats_time_frame(time_frame.value, r)
case UserStatsTimeFrame.total:
self.update_stats_time_frame(time_frame.value, None)
@log_timing(level="INFO")
def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]):
for t in time_frames:
self.update_stats(time_frame=t)
@log_timing(level="INFO")
def update_all_time_frames(self):
self.update_multiple_time_frames(list(UserStatsTimeFrame))
if __name__ == "__main__":
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.database import engine
with Session(engine) as session:
api_client = get_dummy_api_client(session)
usr = UserStatsRepository(session)
usr.update_all_time_frames()
usr.get_leader_board(UserStatsTimeFrame.total)
+1
View File
@@ -1,6 +1,7 @@
alembic==1.8.1
fastapi==0.88.0
fastapi-limiter==0.1.5
fastapi-utils==0.2.1
loguru==0.6.0
numpy==1.22.4
psycopg2-binary==2.9.5
+30 -2
View File
@@ -169,6 +169,8 @@ class RankConversationRepliesTask(Task):
conversation: Conversation # the conversation so far
replies: list[str] # deprecated, use reply_messages
reply_messages: list[ConversationMessage]
message_tree_id: UUID
ranking_parent_id: UUID
class RankPrompterRepliesTask(RankConversationRepliesTask):
@@ -356,14 +358,40 @@ class SystemStats(BaseModel):
class UserScore(BaseModel):
ranking: int
user_rank: int
user_id: UUID
username: str
auth_method: str
display_name: str
score: int
leader_score: int
base_date: Optional[datetime]
modified_date: datetime
prompts: int
replies_assistant: int
replies_prompter: int
labels_simple: int
labels_full: int
rankings_total: int
rankings_good: int
accepted_prompts: int
accepted_replies_assistant: int
accepted_replies_prompter: int
reply_ranked_1: int
reply_ranked_2: int
reply_ranked_3: int
# only used for time frame "total"
streak_last_day_date: Optional[datetime]
streak_days: Optional[int]
class LeaderboardStats(BaseModel):
time_frame: str
leaderboard: List[UserScore]
+26
View File
@@ -1,6 +1,32 @@
import time
from datetime import datetime, timezone
from functools import wraps
from loguru import logger
def utcnow() -> datetime:
"""Return the current utc date and time with tzinfo set to UTC."""
return datetime.now(timezone.utc)
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
elapsed = end - start
if log_kwargs:
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
else:
logger.log(level, f"Function '{func.__name__}' executed in {elapsed:f} s")
return result
return wrapped
if func and callable(func):
return decorator(func)
return decorator