mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add /api/v1/users/{user_id}/stats endpoint (#744)
* add /api/v1/users/{user_id}/stats endpoint
* return 0 stats and add /api/v1/users/{user_id}/stats/{time_frame}
* use utcnow() as modified date for 0 stats
This commit is contained in:
+33
@@ -0,0 +1,33 @@
|
||||
"""add rank and indices to user_stats
|
||||
|
||||
Revision ID: 0964ac95170d
|
||||
Revises: 423557e869e4
|
||||
Create Date: 2023-01-15 16:54:09.510018
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0964ac95170d"
|
||||
down_revision = "423557e869e4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable integer ``rank`` column to ``user_stats`` and index it.

    Two unique indexes are created after the column exists: one over
    ``(time_frame, rank, user_id)`` and one over ``(time_frame, user_id)``.
    """
    table_name = "user_stats"

    # The column must exist before an index can reference it.
    op.add_column(table_name, sa.Column("rank", sa.Integer(), nullable=True))

    # Index creation order matches the auto-generated migration.
    unique_indexes = (
        ("ix_user_stats__timeframe__rank__user_id", ["time_frame", "rank", "user_id"]),
        ("ix_user_stats__timeframe__user_id", ["time_frame", "user_id"]),
    )
    for index_name, columns in unique_indexes:
        op.create_index(index_name, table_name, columns, unique=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo ``upgrade``: drop both ``user_stats`` indexes, then the ``rank`` column."""
    table_name = "user_stats"

    # Indexes are dropped before the column they reference.
    for index_name in (
        "ix_user_stats__timeframe__user_id",
        "ix_user_stats__timeframe__rank__user_id",
    ):
        op.drop_index(index_name, table_name=table_name)

    op.drop_column(table_name, "rank")
|
||||
@@ -204,6 +204,7 @@ def update_leader_board_day() -> None:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.day)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during leaderboard update (daily)")
|
||||
|
||||
@@ -215,6 +216,7 @@ def update_leader_board_week() -> None:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.week)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (weekly)")
|
||||
|
||||
@@ -226,6 +228,7 @@ def update_leader_board_month() -> None:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.month)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (monthly)")
|
||||
|
||||
@@ -237,6 +240,7 @@ def update_leader_board_total() -> None:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.total)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (total)")
|
||||
|
||||
|
||||
@@ -18,4 +18,4 @@ def get_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leader_board(time_frame, limit=max_count)
|
||||
return usr.get_leaderboard(time_frame, limit=max_count)
|
||||
|
||||
@@ -8,6 +8,7 @@ from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -96,3 +97,24 @@ def mark_user_messages_deleted(
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats", response_model=dict[str, protocol.UserScore | None])
def query_user_stats(
    user_id: UUID,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Return the stats of one user for all time frames.

    The result is whatever ``UserStatsRepository.get_user_stats_all_time_frames``
    produces for ``user_id``: a mapping from time-frame name to score.

    ``api_client`` is not used in the body; it is only resolved as a
    dependency (presumably to require a valid API client -- confirm in ``deps``).
    """
    return UserStatsRepository(db).get_user_stats_all_time_frames(user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats/{time_frame}", response_model=protocol.UserScore)
def query_user_stats_timeframe(
    user_id: UUID,
    time_frame: UserStatsTimeFrame,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Return the stats of one user for a single time frame.

    Args:
        user_id: Id of the user whose stats are requested.
        time_frame: Time frame to select from the per-time-frame result.
        api_client: Resolved as a dependency only; not used in the body.
        db: Database session handed to ``UserStatsRepository``.

    Returns:
        The score entry stored under ``time_frame.value``.

    Raises:
        HTTPException: 404 when the repository result has no entry for the
            requested time frame (for example when the query matched no row
            at all). Previously this surfaced as an unhandled ``KeyError``.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module (``Depends`` is used above).
    from fastapi import HTTPException

    stats_by_timeframe = UserStatsRepository(db).get_user_stats_all_time_frames(user_id=user_id)

    # Use .get(): the mapping is not guaranteed to contain every time frame.
    score = stats_by_timeframe.get(time_frame.value)
    if score is None:
        raise HTTPException(
            status_code=404,
            detail=f"No stats found for user {user_id} and time frame '{time_frame.value}'",
        )
    return score
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class UserStatsTimeFrame(str, Enum):
|
||||
@@ -17,11 +17,15 @@ class UserStatsTimeFrame(str, Enum):
|
||||
|
||||
class UserStats(SQLModel, table=True):
|
||||
__tablename__ = "user_stats"
|
||||
__table_args__ = (
|
||||
Index("ix_user_stats__timeframe__user_id", "time_frame", "user_id", unique=True),
|
||||
Index("ix_user_stats__timeframe__rank__user_id", "time_frame", "rank", "user_id", unique=True),
|
||||
)
|
||||
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
|
||||
)
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
|
||||
leader_score: int = 0
|
||||
@@ -29,6 +33,8 @@ class UserStats(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
rank: int = Field(nullable=True)
|
||||
|
||||
prompts: int = 0
|
||||
replies_assistant: int = 0
|
||||
replies_prompter: int = 0
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
@@ -12,14 +13,24 @@ from oasst_backend.models.db_payload import (
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
|
||||
from oasst_shared.utils import log_timing, utcnow
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlmodel import Session, delete, func
|
||||
from sqlmodel import Session, delete, func, text
|
||||
|
||||
|
||||
def _create_user_score(r):
    """Build a ``UserScore`` from one result row.

    ``r`` must support key access for ``"UserStats"`` and for the user
    columns ``user_id``, ``username``, ``auth_method`` and ``display_name``.
    When the row carries no ``UserStats`` object, only ``modified_date``
    (set to the current UTC time) and the user columns are passed on, so
    all remaining fields fall back to the defaults of ``UserScore``.
    """
    stats = r["UserStats"]
    if not stats:
        # No stats row: start from an almost empty field set.
        fields = {"modified_date": utcnow()}
    else:
        fields = stats.dict()

    # User identity columns always come from the row itself.
    fields.update((key, r[key]) for key in ("user_id", "username", "auth_method", "display_name"))
    return UserScore(**fields)
|
||||
|
||||
|
||||
class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_leader_board(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
|
||||
def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for the specified time frame
|
||||
"""
|
||||
@@ -32,15 +43,25 @@ class UserStatsRepository:
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
def create_user_score(user_rank: int, r):
|
||||
d = r["UserStats"].dict()
|
||||
for k in ["user_id", "username", "auth_method", "display_name"]:
|
||||
d[k] = r[k]
|
||||
return UserScore(user_rank=user_rank, **d)
|
||||
|
||||
leaderboard = [create_user_score(i, r) for i, r in enumerate(self.session.exec(qry))]
|
||||
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
|
||||
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
|
||||
|
||||
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
    """Collect the scores of a single user, keyed by time frame.

    The user is outer-joined with ``UserStats``. Each row that carries a
    stats object contributes one entry under that object's ``time_frame``.
    A row without a stats object replaces the result with one score per
    member of ``UserStatsTimeFrame``, all built from that row.
    If the query yields no rows the result is an empty dict.
    """
    user_columns = (User.id.label("user_id"), User.username, User.auth_method, User.display_name)
    qry = self.session.query(*user_columns, UserStats)
    qry = qry.outerjoin(UserStats, User.id == UserStats.user_id)
    qry = qry.filter(User.id == user_id)

    scores = {}
    for row in self.session.exec(qry):
        stats = row["UserStats"]
        if stats is None:
            # User exists but has no stats rows: report every time frame.
            scores = {tf.value: _create_user_score(row) for tf in UserStatsTimeFrame}
            continue
        scores[stats.time_frame] = _create_user_score(row)
    return scores
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
):
|
||||
@@ -102,9 +123,11 @@ class UserStatsRepository:
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def _update_stats_internal(self, time_frame_key: str, base_date: Optional[datetime] = None):
|
||||
def _update_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
|
||||
# gather user data
|
||||
|
||||
time_frame_key = time_frame.value
|
||||
|
||||
stats_by_user: dict[UUID, UserStats] = dict()
|
||||
now = utcnow()
|
||||
|
||||
@@ -194,8 +217,46 @@ class UserStatsRepository:
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
def update_stats_time_frame(self, time_frame_key: str, reference_time: Optional[datetime] = None):
|
||||
self._update_stats_internal(time_frame_key, reference_time)
|
||||
self.update_ranks(time_frame=time_frame)
|
||||
|
||||
@log_timing(log_kwargs=True)
def update_ranks(self, time_frame: Optional[UserStatsTimeFrame] = None):
    """Recompute the persisted ``rank`` column of ``user_stats``.

    The persisted rank values allow to quickly look up the rank of a single
    user and to query nearby users.

    Ranks are assigned with ``ROW_NUMBER()`` per ``time_frame``, ordered by
    ``leader_score`` descending and then by ``user_id``.

    Args:
        time_frame: Restrict the update to this time frame. When ``None``
            the SQL filter is disabled and all time frames are re-ranked.

    The statement is executed on ``self.session``; this method does not
    commit.
    """

    # todo: convert sql to sqlalchemy query..
    # ranks = self.session.query(
    #     func.row_number()
    #     .over(partition_by=UserStats.time_frame, order_by=[UserStats.leader_score.desc(), UserStats.user_id])
    #     .label("rank"),
    #     UserStats.user_id,
    #     UserStats.time_frame,
    # )

    sql_update_rank = """
-- update rank
UPDATE user_stats us
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY leader_score DESC, user_id
) AS "rank", user_id, time_frame
FROM user_stats
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
us.user_id = r.user_id
AND us.time_frame = r.time_frame;"""

    # A NULL bind value turns the inner WHERE into a no-op filter.
    r = self.session.execute(
        text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
    )
    logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
    """Recompute the stats of one time frame and commit the session.

    ``reference_time`` is forwarded to ``_update_stats_internal`` as its
    ``base_date`` argument.
    """
    self._update_stats_internal(time_frame=time_frame, base_date=reference_time)
    self.session.commit()
|
||||
|
||||
@log_timing(log_kwargs=True, level="INFO")
|
||||
@@ -204,20 +265,20 @@ class UserStatsRepository:
|
||||
match time_frame:
|
||||
case UserStatsTimeFrame.day:
|
||||
r = now - timedelta(days=1)
|
||||
self.update_stats_time_frame(time_frame.value, r)
|
||||
self.update_stats_time_frame(time_frame, r)
|
||||
|
||||
case UserStatsTimeFrame.week:
|
||||
r = now.date() - timedelta(days=7)
|
||||
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
|
||||
self.update_stats_time_frame(time_frame.value, r)
|
||||
self.update_stats_time_frame(time_frame, r)
|
||||
|
||||
case UserStatsTimeFrame.month:
|
||||
r = now.date() - timedelta(days=30)
|
||||
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
|
||||
self.update_stats_time_frame(time_frame.value, r)
|
||||
self.update_stats_time_frame(time_frame, r)
|
||||
|
||||
case UserStatsTimeFrame.total:
|
||||
self.update_stats_time_frame(time_frame.value, None)
|
||||
self.update_stats_time_frame(time_frame, None)
|
||||
|
||||
@log_timing(level="INFO")
|
||||
def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]):
|
||||
@@ -236,5 +297,7 @@ if __name__ == "__main__":
|
||||
with Session(engine) as session:
|
||||
api_client = get_dummy_api_client(session)
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_all_time_frames()
|
||||
usr.get_leader_board(UserStatsTimeFrame.total)
|
||||
# usr.update_all_time_frames()
|
||||
# session.commit()
|
||||
# usr.get_leader_board(UserStatsTimeFrame.total)
|
||||
usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
|
||||
|
||||
@@ -358,32 +358,32 @@ class SystemStats(BaseModel):
|
||||
|
||||
|
||||
class UserScore(BaseModel):
|
||||
user_rank: int
|
||||
rank: Optional[int]
|
||||
user_id: UUID
|
||||
username: str
|
||||
auth_method: str
|
||||
display_name: str
|
||||
|
||||
leader_score: int
|
||||
leader_score: int = 0
|
||||
|
||||
base_date: Optional[datetime]
|
||||
modified_date: datetime
|
||||
modified_date: Optional[datetime]
|
||||
|
||||
prompts: int
|
||||
replies_assistant: int
|
||||
replies_prompter: int
|
||||
labels_simple: int
|
||||
labels_full: int
|
||||
rankings_total: int
|
||||
rankings_good: int
|
||||
prompts: int = 0
|
||||
replies_assistant: int = 0
|
||||
replies_prompter: int = 0
|
||||
labels_simple: int = 0
|
||||
labels_full: int = 0
|
||||
rankings_total: int = 0
|
||||
rankings_good: int = 0
|
||||
|
||||
accepted_prompts: int
|
||||
accepted_replies_assistant: int
|
||||
accepted_replies_prompter: int
|
||||
accepted_prompts: int = 0
|
||||
accepted_replies_assistant: int = 0
|
||||
accepted_replies_prompter: int = 0
|
||||
|
||||
reply_ranked_1: int
|
||||
reply_ranked_2: int
|
||||
reply_ranked_3: int
|
||||
reply_ranked_1: int = 0
|
||||
reply_ranked_2: int = 0
|
||||
reply_ranked_3: int = 0
|
||||
|
||||
# only used for time frame "total"
|
||||
streak_last_day_date: Optional[datetime]
|
||||
|
||||
Reference in New Issue
Block a user