mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add leaderboard stats, periodic updates via fastapi-utils (#724)
* add leaderboard stats, periodic update via fastapi-utils * count label tasks for assistant and prompter replies * Daily stats update every 15 mins, simplify leaderboard endpoint * add indices for some created_date columns * make user stats update intervals configurable * make sure intervals are positive
This commit is contained in:
+83
@@ -0,0 +1,83 @@
|
||||
"""change user_stats ranking counts
|
||||
|
||||
Revision ID: 7c98102efbca
|
||||
Revises: 619255ae9076
|
||||
Create Date: 2023-01-15 00:02:45.622986
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7c98102efbca"
|
||||
down_revision = "619255ae9076"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Drop the ``user_stats`` table and recreate it with the new column set.

    The recreated table has a single ``reply_ranked_1..3`` counter set, a
    nullable ``base_date`` column, and a composite primary key on
    ``(user_id, time_frame)``. Because the table is dropped first, any rows
    it held are discarded.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("user_stats")
    op.create_table(
        "user_stats",
        sa.Column("user_id", UUID(as_uuid=True), nullable=False),
        sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("base_date", sa.DateTime(), nullable=True),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("leader_score", sa.Integer(), nullable=False),
        sa.Column("prompts", sa.Integer(), nullable=False),
        sa.Column("replies_assistant", sa.Integer(), nullable=False),
        sa.Column("replies_prompter", sa.Integer(), nullable=False),
        sa.Column("labels_simple", sa.Integer(), nullable=False),
        sa.Column("labels_full", sa.Integer(), nullable=False),
        sa.Column("rankings_total", sa.Integer(), nullable=False),
        sa.Column("rankings_good", sa.Integer(), nullable=False),
        sa.Column("accepted_prompts", sa.Integer(), nullable=False),
        sa.Column("accepted_replies_assistant", sa.Integer(), nullable=False),
        sa.Column("accepted_replies_prompter", sa.Integer(), nullable=False),
        sa.Column("reply_ranked_1", sa.Integer(), nullable=False),
        sa.Column("reply_ranked_2", sa.Integer(), nullable=False),
        sa.Column("reply_ranked_3", sa.Integer(), nullable=False),
        sa.Column("streak_last_day_date", sa.DateTime(), nullable=True),
        sa.Column("streak_days", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        # One row per user and time frame.
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    )
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore the per-role ``reply_*_ranked_N`` columns and drop ``reply_ranked_1..3``.

    The re-added columns are NOT NULL with a server default of ``"0"``, so
    rows already present in ``user_stats`` receive 0 for each of them.
    """
    # NOTE(review): upgrade() recreates the table with a ``base_date`` column that is
    # not dropped here — confirm whether the previous revision already had it.
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "user_stats",
        sa.Column("reply_prompter_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
    )
    op.add_column(
        "user_stats",
        sa.Column("reply_assistant_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
    )
    op.add_column(
        "user_stats",
        sa.Column("reply_assistant_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
    )
    op.add_column(
        "user_stats",
        sa.Column("reply_prompter_ranked_2", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
    )
    op.add_column(
        "user_stats",
        sa.Column("reply_prompter_ranked_1", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
    )
    op.add_column(
        "user_stats",
        sa.Column("reply_assistant_ranked_3", sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
    )
    op.drop_column("user_stats", "reply_ranked_3")
    op.drop_column("user_stats", "reply_ranked_2")
    op.drop_column("user_stats", "reply_ranked_1")
    # ### end Alembic commands ###
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add indices for created_date
|
||||
|
||||
Revision ID: 423557e869e4
|
||||
Revises: 7c98102efbca
|
||||
Create Date: 2023-01-15 11:39:10.407859
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "423557e869e4"
|
||||
down_revision = "7c98102efbca"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create a non-unique ``created_date`` index on message, message_reaction and text_labels."""
    # ### commands auto generated by Alembic - please adjust! ###
    for table in ("message", "message_reaction", "text_labels"):
        index_name = op.f(f"ix_{table}_created_date")
        op.create_index(index_name, table, ["created_date"], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``created_date`` indexes, in the reverse order of their creation."""
    # ### commands auto generated by Alembic - please adjust! ###
    for table in ("text_labels", "message_reaction", "message"):
        index_name = op.f(f"ix_{table}_created_date")
        op.drop_index(index_name, table_name=table)
    # ### end Alembic commands ###
|
||||
@@ -9,6 +9,7 @@ import alembic.config
|
||||
import fastapi
|
||||
import redis.asyncio as redis
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
@@ -18,6 +19,7 @@ from oasst_backend.database import engine
|
||||
from oasst_backend.models import message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -195,6 +197,50 @@ def ensure_tree_states():
|
||||
logger.exception("TreeManager.ensure_tree_states() failed.")
|
||||
|
||||
|
||||
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
def update_leader_board_day() -> None:
    """Recompute the "day" user stats.

    Registered as a startup handler and repeated every
    ``settings.USER_STATS_INTERVAL_DAY`` minutes (converted to seconds above).
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.day)
    except Exception:
        # Any failure is logged with its traceback and not re-raised.
        logger.exception("Error during leaderboard update (daily)")
|
||||
|
||||
|
||||
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
def update_leader_board_week() -> None:
    """Recompute the "week" user stats.

    Registered as a startup handler and repeated every
    ``settings.USER_STATS_INTERVAL_WEEK`` minutes (converted to seconds above).
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.week)
    except Exception:
        # Any failure is logged with its traceback and not re-raised.
        logger.exception("Error during user stats update (weekly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
def update_leader_board_month() -> None:
    """Recompute the "month" user stats.

    Registered as a startup handler and repeated every
    ``settings.USER_STATS_INTERVAL_MONTH`` minutes (converted to seconds above).
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.month)
    except Exception:
        # Any failure is logged with its traceback and not re-raised.
        logger.exception("Error during user stats update (monthly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
def update_leader_board_total() -> None:
    """Recompute the "total" (all-time) user stats.

    Registered as a startup handler and repeated every
    ``settings.USER_STATS_INTERVAL_TOTAL`` minutes (converted to seconds above).
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.total)
    except Exception:
        # Any failure is logged with its traceback and not re-raised.
        logger.exception("Error during user stats update (total)")
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
|
||||
@@ -19,5 +19,5 @@ api_router.include_router(frontend_messages.router, prefix="/frontend_messages",
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/create/assistant")
|
||||
def get_assistant_leaderboard(
|
||||
@router.get("/{time_frame}")
|
||||
def get_leaderboard_day(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="assistant")
|
||||
|
||||
|
||||
@router.get("/create/prompter")
|
||||
def get_prompter_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="prompter")
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leader_board(time_frame, limit=max_count)
|
||||
|
||||
@@ -109,6 +109,22 @@ class Settings(BaseSettings):
|
||||
|
||||
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
|
||||
|
||||
USER_STATS_INTERVAL_DAY: int = 15 # minutes
|
||||
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
|
||||
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
|
||||
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
|
||||
|
||||
@validator(
    "USER_STATS_INTERVAL_DAY",
    "USER_STATS_INTERVAL_WEEK",
    "USER_STATS_INTERVAL_MONTH",
    "USER_STATS_INTERVAL_TOTAL",
)
def validate_user_stats_intervals(cls, v: int):
    """Reject user-stats update intervals smaller than one minute.

    Args:
        v: The configured interval in minutes.

    Returns:
        ``v`` unchanged when it is at least 1.

    Raises:
        ValueError: If ``v`` is less than 1.
    """
    if v < 1:
        # Say what is wrong instead of raising with only the bare number.
        raise ValueError(f"user stats update interval must be at least 1 minute, got {v}")
    return v
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@@ -8,12 +8,13 @@ from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .user import User
|
||||
from .user_stats import UserStats
|
||||
from .user_stats import UserStats, UserStatsTimeFrame
|
||||
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
"User",
|
||||
"UserStats",
|
||||
"UserStatsTimeFrame",
|
||||
"Message",
|
||||
"MessageEmbedding",
|
||||
"MessageReaction",
|
||||
|
||||
@@ -65,12 +65,16 @@ class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
ranking: list[int]
|
||||
ranked_message_ids: list[UUID]
|
||||
ranking_parent_id: Optional[UUID]
|
||||
message_tree_id: Optional[UUID]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankConversationRepliesPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation # the conversation so far
|
||||
reply_messages: list[protocol_schema.ConversationMessage]
|
||||
ranking_parent_id: Optional[UUID]
|
||||
message_tree_id: Optional[UUID]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -104,6 +108,7 @@ class LabelInitialPromptPayload(TaskPayload):
|
||||
prompt: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -115,6 +120,7 @@ class LabelConversationReplyPayload(TaskPayload):
|
||||
reply: str
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
|
||||
@payload_type
|
||||
|
||||
@@ -30,7 +30,7 @@ class Message(SQLModel, table=True):
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: Optional[PayloadContainer] = Field(
|
||||
|
||||
@@ -19,7 +19,7 @@ class MessageReaction(SQLModel, table=True):
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_shared.utils import utcnow
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
@@ -35,4 +36,4 @@ class Task(SQLModel, table=True):
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
|
||||
return self.expiry_date is not None and utcnow() > self.expiry_date
|
||||
|
||||
@@ -17,7 +17,7 @@ class TextLabels(SQLModel, table=True):
|
||||
)
|
||||
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
|
||||
)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
text: str = Field(nullable=False, max_length=2**16)
|
||||
|
||||
@@ -22,6 +22,7 @@ class UserStats(SQLModel, table=True):
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
|
||||
)
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
@@ -40,14 +41,27 @@ class UserStats(SQLModel, table=True):
|
||||
accepted_replies_assistant: int = 0
|
||||
accepted_replies_prompter: int = 0
|
||||
|
||||
reply_assistant_ranked_1: int = 0
|
||||
reply_assistant_ranked_2: int = 0
|
||||
reply_assistant_ranked_3: int = 0
|
||||
|
||||
reply_prompter_ranked_1: int = 0
|
||||
reply_prompter_ranked_2: int = 0
|
||||
reply_prompter_ranked_3: int = 0
|
||||
reply_ranked_1: int = 0
|
||||
reply_ranked_2: int = 0
|
||||
reply_ranked_3: int = 0
|
||||
|
||||
# only used for time span "total"
|
||||
streak_last_day_date: Optional[datetime] = Field(nullable=True)
|
||||
streak_days: Optional[int] = Field(nullable=True)
|
||||
|
||||
def compute_leader_score(self) -> int:
    """Return the leaderboard score: a weighted sum of this row's counters."""
    # (weight, count) pairs; the weights are the scoring formula itself.
    weighted_counts = (
        (1, self.prompts),
        (4, self.replies_assistant),
        (1, self.replies_prompter),
        (1, self.labels_simple),
        (2, self.labels_full),
        (1, self.rankings_total),
        (1, self.rankings_good),
        (1, self.accepted_prompts),
        (4, self.accepted_replies_assistant),
        (1, self.accepted_replies_prompter),
        (9, self.reply_ranked_1),
        (3, self.reply_ranked_2),
        (1, self.reply_ranked_3),
    )
    return sum(weight * count for weight, count in weighted_counts)
|
||||
|
||||
@@ -260,7 +260,10 @@ class PromptRepository:
|
||||
self.db.add(message)
|
||||
|
||||
reaction_payload = db_payload.RankingReactionPayload(
|
||||
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
|
||||
ranking=ranking.ranking,
|
||||
ranked_message_ids=ranked_message_ids,
|
||||
ranking_parent_id=task_payload.ranking_parent_id,
|
||||
message_tree_id=task_payload.message_tree_id,
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
|
||||
|
||||
@@ -67,17 +67,30 @@ class TaskRepository:
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
|
||||
type=task.type,
|
||||
conversation=task.conversation,
|
||||
reply_messages=task.reply_messages,
|
||||
ranking_parent_id=task.ranking_parent_id,
|
||||
message_tree_id=task.message_tree_id,
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, reply_messages=task.reply_messages
|
||||
type=task.type,
|
||||
conversation=task.conversation,
|
||||
reply_messages=task.reply_messages,
|
||||
ranking_parent_id=task.ranking_parent_id,
|
||||
message_tree_id=task.message_tree_id,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
prompt=task.prompt,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
@@ -88,6 +101,7 @@ class TaskRepository:
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
@@ -98,6 +112,7 @@ class TaskRepository:
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
)
|
||||
|
||||
case _:
|
||||
|
||||
@@ -207,6 +207,9 @@ class TreeManager:
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
|
||||
@@ -218,12 +221,20 @@ class TreeManager:
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=conversation, replies=replies, reply_messages=reply_messages
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=conversation, replies=replies, reply_messages=reply_messages
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
|
||||
parent_message_id = ranking_parent_id
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -137,27 +136,6 @@ class UserRepository:
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
    """
    Get leaderboard stats for Messages created,
    separate leaderboard for prompts & assistants

    Args:
        role: Message role to count; messages are filtered on ``Message.role == role``.

    Returns:
        LeaderboardStats whose entries are ordered by message count, descending,
        with ``ranking`` starting at 1.
    """
    query = (
        self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
        .join(User, User.id == Message.user_id, isouter=True)
        # `Message.deleted is not True` was a Python identity test that always evaluated
        # to True, so deleted messages were never excluded. Use a SQL `IS NOT true` test.
        .filter(Message.deleted.is_not(True), Message.role == role)
        .group_by(Message.user_id, User.username, User.display_name)
        .order_by(func.count(Message.user_id).desc())
    )

    result = [
        {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
        for i, j in enumerate(query.all(), start=1)
    ]

    return LeaderboardStats(leaderboard=result)
|
||||
|
||||
def query_users(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
LabelPrompterReplyPayload,
|
||||
RankingReactionPayload,
|
||||
)
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
|
||||
from oasst_shared.utils import log_timing, utcnow
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlmodel import Session, delete, func
|
||||
|
||||
|
||||
class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_leader_board(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for the specified time frame
|
||||
"""
|
||||
|
||||
qry = (
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value)
|
||||
.order_by(UserStats.leader_score.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
def create_user_score(user_rank: int, r):
|
||||
d = r["UserStats"].dict()
|
||||
for k in ["user_id", "username", "auth_method", "display_name"]:
|
||||
d[k] = r[k]
|
||||
return UserScore(user_rank=user_rank, **d)
|
||||
|
||||
leaderboard = [create_user_score(i, r) for i, r in enumerate(self.session.exec(qry))]
|
||||
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
):
|
||||
qry = self.session.query(Message.user_id, func.count()).filter(
|
||||
Message.deleted == sa.false(), Message.parent_id.is_(None)
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
if only_reviewed:
|
||||
qry = qry.filter(Message.review_result == sa.true())
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def query_replies_by_role_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
) -> list:
|
||||
qry = self.session.query(Message.user_id, Message.role, func.count()).filter(
|
||||
Message.deleted == sa.false(), Message.parent_id.is_not(None)
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
if only_reviewed:
|
||||
qry = qry.filter(Message.review_result == sa.true())
|
||||
qry = qry.group_by(Message.user_id, Message.role)
|
||||
return qry
|
||||
|
||||
def query_labels_by_mode_per_user(
|
||||
self, payload_type: str = LabelAssistantReplyPayload.__name__, reference_time: Optional[datetime] = None
|
||||
):
|
||||
qry = self.session.query(Task.user_id, Task.payload["payload", "mode"].astext, func.count()).filter(
|
||||
Task.done == sa.true(), Task.payload_type == payload_type
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(Task.created_date >= reference_time)
|
||||
qry = qry.group_by(Task.user_id, Task.payload["payload", "mode"].astext)
|
||||
return qry
|
||||
|
||||
def query_rankings_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = self.session.query(MessageReaction.user_id, func.count()).filter(
|
||||
MessageReaction.payload_type == RankingReactionPayload.__name__
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(MessageReaction.created_date >= reference_time)
|
||||
qry = qry.group_by(MessageReaction.user_id)
|
||||
return qry
|
||||
|
||||
def query_ranking_result_users(self, rank: int = 0, reference_time: Optional[datetime] = None):
|
||||
ranked_message_id = MessageReaction.payload["payload", "ranked_message_ids", rank].astext.cast(
|
||||
postgresql.UUID(as_uuid=True)
|
||||
)
|
||||
qry = (
|
||||
self.session.query(Message.user_id, func.count())
|
||||
.select_from(MessageReaction)
|
||||
.join(Message, ranked_message_id == Message.id)
|
||||
.filter(MessageReaction.payload_type == RankingReactionPayload.__name__)
|
||||
)
|
||||
if reference_time:
|
||||
qry = qry.filter(MessageReaction.created_date >= reference_time)
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def _update_stats_internal(self, time_frame_key: str, base_date: Optional[datetime] = None):
|
||||
# gather user data
|
||||
|
||||
stats_by_user: dict[UUID, UserStats] = dict()
|
||||
now = utcnow()
|
||||
|
||||
def get_stats(id: UUID) -> UserStats:
|
||||
us = stats_by_user.get(id)
|
||||
if not us:
|
||||
us = UserStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
|
||||
stats_by_user[id] = us
|
||||
return us
|
||||
|
||||
# total prompts
|
||||
qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=False)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
get_stats(uid).prompts = count
|
||||
|
||||
# accepted prompts
|
||||
qry = self.query_total_prompts_per_user(reference_time=base_date, only_reviewed=True)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
get_stats(uid).accepted_prompts = count
|
||||
|
||||
# total replies
|
||||
qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=False)
|
||||
for r in qry:
|
||||
uid, role, count = r
|
||||
s = get_stats(uid)
|
||||
if role == "assistant":
|
||||
s.replies_assistant += count
|
||||
elif role == "prompter":
|
||||
s.replies_prompter += count
|
||||
|
||||
# accepted replies
|
||||
qry = self.query_replies_by_role_per_user(reference_time=base_date, only_reviewed=True)
|
||||
for r in qry:
|
||||
uid, role, count = r
|
||||
s = get_stats(uid)
|
||||
if role == "assistant":
|
||||
s.accepted_replies_assistant += count
|
||||
elif role == "prompter":
|
||||
s.accepted_replies_prompter += count
|
||||
|
||||
# simple and full labels
|
||||
qry = self.query_labels_by_mode_per_user(
|
||||
payload_type=LabelAssistantReplyPayload.__name__, reference_time=base_date
|
||||
)
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == "simple":
|
||||
s.labels_simple = count
|
||||
elif mode == "full":
|
||||
s.labels_full = count
|
||||
|
||||
qry = self.query_labels_by_mode_per_user(
|
||||
payload_type=LabelPrompterReplyPayload.__name__, reference_time=base_date
|
||||
)
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == "simple":
|
||||
s.labels_simple += count
|
||||
elif mode == "full":
|
||||
s.labels_full += count
|
||||
|
||||
qry = self.query_rankings_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
get_stats(uid).rankings_total = count
|
||||
|
||||
rank_field_names = ["reply_ranked_1", "reply_ranked_2", "reply_ranked_3"]
|
||||
for i, fn in enumerate(rank_field_names):
|
||||
qry = self.query_ranking_result_users(reference_time=base_date, rank=0)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
setattr(get_stats(uid), fn, count)
|
||||
|
||||
# delete all existing stast for time frame
|
||||
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.leader_score = v.compute_leader_score()
|
||||
|
||||
# insert user objects
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
def update_stats_time_frame(self, time_frame_key: str, reference_time: Optional[datetime] = None):
|
||||
self._update_stats_internal(time_frame_key, reference_time)
|
||||
self.session.commit()
|
||||
|
||||
@log_timing(log_kwargs=True, level="INFO")
|
||||
def update_stats(self, *, time_frame: UserStatsTimeFrame):
|
||||
now = utcnow()
|
||||
match time_frame:
|
||||
case UserStatsTimeFrame.day:
|
||||
r = now - timedelta(days=1)
|
||||
self.update_stats_time_frame(time_frame.value, r)
|
||||
|
||||
case UserStatsTimeFrame.week:
|
||||
r = now.date() - timedelta(days=7)
|
||||
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
|
||||
self.update_stats_time_frame(time_frame.value, r)
|
||||
|
||||
case UserStatsTimeFrame.month:
|
||||
r = now.date() - timedelta(days=30)
|
||||
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
|
||||
self.update_stats_time_frame(time_frame.value, r)
|
||||
|
||||
case UserStatsTimeFrame.total:
|
||||
self.update_stats_time_frame(time_frame.value, None)
|
||||
|
||||
@log_timing(level="INFO")
|
||||
def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]):
|
||||
for t in time_frames:
|
||||
self.update_stats(time_frame=t)
|
||||
|
||||
@log_timing(level="INFO")
|
||||
def update_all_time_frames(self):
|
||||
self.update_multiple_time_frames(list(UserStatsTimeFrame))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual entry point: recompute the stats for every time frame, then
    # query the "total" leaderboard once.
    from oasst_backend.api.deps import get_dummy_api_client
    from oasst_backend.database import engine

    with Session(engine) as session:
        # NOTE(review): api_client is not used below; presumably the call is kept
        # for a side effect — confirm.
        api_client = get_dummy_api_client(session)
        usr = UserStatsRepository(session)
        usr.update_all_time_frames()
        # The returned leaderboard is discarded.
        usr.get_leader_board(UserStatsTimeFrame.total)
|
||||
@@ -1,6 +1,7 @@
|
||||
alembic==1.8.1
|
||||
fastapi==0.88.0
|
||||
fastapi-limiter==0.1.5
|
||||
fastapi-utils==0.2.1
|
||||
loguru==0.6.0
|
||||
numpy==1.22.4
|
||||
psycopg2-binary==2.9.5
|
||||
|
||||
@@ -169,6 +169,8 @@ class RankConversationRepliesTask(Task):
|
||||
conversation: Conversation # the conversation so far
|
||||
replies: list[str] # deprecated, use reply_messages
|
||||
reply_messages: list[ConversationMessage]
|
||||
message_tree_id: UUID
|
||||
ranking_parent_id: UUID
|
||||
|
||||
|
||||
class RankPrompterRepliesTask(RankConversationRepliesTask):
|
||||
@@ -356,14 +358,40 @@ class SystemStats(BaseModel):
|
||||
|
||||
|
||||
class UserScore(BaseModel):
|
||||
ranking: int
|
||||
user_rank: int
|
||||
user_id: UUID
|
||||
username: str
|
||||
auth_method: str
|
||||
display_name: str
|
||||
score: int
|
||||
|
||||
leader_score: int
|
||||
|
||||
base_date: Optional[datetime]
|
||||
modified_date: datetime
|
||||
|
||||
prompts: int
|
||||
replies_assistant: int
|
||||
replies_prompter: int
|
||||
labels_simple: int
|
||||
labels_full: int
|
||||
rankings_total: int
|
||||
rankings_good: int
|
||||
|
||||
accepted_prompts: int
|
||||
accepted_replies_assistant: int
|
||||
accepted_replies_prompter: int
|
||||
|
||||
reply_ranked_1: int
|
||||
reply_ranked_2: int
|
||||
reply_ranked_3: int
|
||||
|
||||
# only used for time frame "total"
|
||||
streak_last_day_date: Optional[datetime]
|
||||
streak_days: Optional[int]
|
||||
|
||||
|
||||
class LeaderboardStats(BaseModel):
|
||||
time_frame: str
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,32 @@
|
||||
import time
from datetime import datetime, timezone
from functools import wraps
from typing import Callable

from loguru import logger
|
||||
|
||||
|
||||
def utcnow() -> datetime:
    """Return the current date and time as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
|
||||
|
||||
|
||||
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
start = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
elapsed = end - start
|
||||
if log_kwargs:
|
||||
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
|
||||
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
|
||||
else:
|
||||
logger.log(level, f"Function '{func.__name__}' executed in {elapsed:f} s")
|
||||
return result
|
||||
|
||||
return wrapped
|
||||
|
||||
if func and callable(func):
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
Reference in New Issue
Block a user