From e58ffd64fa028945a6b93812620c4b07b9ec619b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 15 Jan 2023 21:24:15 +0100 Subject: [PATCH] add /api/v1/users/{user_id}/stats endpoint (#744) * add /api/v1/users/{user_id}/stats endpoint * return 0 stats and add /api/v1/users/{user_id}/stats/{time_frame} * use utcnow() as modified date for 0 stats --- ...170d_add_rank_and_indices_to_user_stats.py | 33 +++++++ backend/main.py | 4 + backend/oasst_backend/api/v1/leaderboards.py | 2 +- backend/oasst_backend/api/v1/users.py | 22 +++++ backend/oasst_backend/models/user_stats.py | 10 +- .../oasst_backend/user_stats_repository.py | 99 +++++++++++++++---- oasst-shared/oasst_shared/schemas/protocol.py | 32 +++--- 7 files changed, 165 insertions(+), 37 deletions(-) create mode 100644 backend/alembic/versions/2023_01_15_1654-0964ac95170d_add_rank_and_indices_to_user_stats.py diff --git a/backend/alembic/versions/2023_01_15_1654-0964ac95170d_add_rank_and_indices_to_user_stats.py b/backend/alembic/versions/2023_01_15_1654-0964ac95170d_add_rank_and_indices_to_user_stats.py new file mode 100644 index 00000000..8e4abe49 --- /dev/null +++ b/backend/alembic/versions/2023_01_15_1654-0964ac95170d_add_rank_and_indices_to_user_stats.py @@ -0,0 +1,33 @@ +"""add rank and indices to user_stats + +Revision ID: 0964ac95170d +Revises: 423557e869e4 +Create Date: 2023-01-15 16:54:09.510018 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0964ac95170d" +down_revision = "423557e869e4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("user_stats", sa.Column("rank", sa.Integer(), nullable=True)) + op.create_index( + "ix_user_stats__timeframe__rank__user_id", "user_stats", ["time_frame", "rank", "user_id"], unique=True + ) + op.create_index("ix_user_stats__timeframe__user_id", "user_stats", ["time_frame", "user_id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_user_stats__timeframe__user_id", table_name="user_stats") + op.drop_index("ix_user_stats__timeframe__rank__user_id", table_name="user_stats") + op.drop_column("user_stats", "rank") + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index fb7aa26b..3787100e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -204,6 +204,7 @@ def update_leader_board_day() -> None: with Session(engine) as session: usr = UserStatsRepository(session) usr.update_stats(time_frame=UserStatsTimeFrame.day) + session.commit() except Exception: logger.exception("Error during leaderboard update (daily)") @@ -215,6 +216,7 @@ def update_leader_board_week() -> None: with Session(engine) as session: usr = UserStatsRepository(session) usr.update_stats(time_frame=UserStatsTimeFrame.week) + session.commit() except Exception: logger.exception("Error during user states update (weekly)") @@ -226,6 +228,7 @@ def update_leader_board_month() -> None: with Session(engine) as session: usr = UserStatsRepository(session) usr.update_stats(time_frame=UserStatsTimeFrame.month) + session.commit() except Exception: logger.exception("Error during user states update (monthly)") @@ -237,6 +240,7 @@ def update_leader_board_total() -> None: with Session(engine) as session: usr = UserStatsRepository(session) usr.update_stats(time_frame=UserStatsTimeFrame.total) + session.commit() except Exception: logger.exception("Error during user states update (total)") diff --git a/backend/oasst_backend/api/v1/leaderboards.py 
b/backend/oasst_backend/api/v1/leaderboards.py index 6341df36..213855a1 100644 --- a/backend/oasst_backend/api/v1/leaderboards.py +++ b/backend/oasst_backend/api/v1/leaderboards.py @@ -18,4 +18,4 @@ def get_leaderboard( db: Session = Depends(deps.get_db), ) -> LeaderboardStats: usr = UserStatsRepository(db) - return usr.get_leader_board(time_frame, limit=max_count) + return usr.get_leaderboard(time_frame, limit=max_count) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 565499a7..36cd65c9 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -8,6 +8,7 @@ from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient, User from oasst_backend.prompt_repository import PromptRepository from oasst_backend.user_repository import UserRepository +from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -96,3 +97,24 @@ def mark_user_messages_deleted( pr = PromptRepository(db, api_client) messages = pr.query_messages(user_id=user_id) pr.mark_messages_deleted(messages) + + +@router.get("/{user_id}/stats", response_model=dict[str, protocol.UserScore | None]) +def query_user_stats( + user_id: UUID, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + usr = UserStatsRepository(db) + return usr.get_user_stats_all_time_frames(user_id=user_id) + + +@router.get("/{user_id}/stats/{time_frame}", response_model=protocol.UserScore) +def query_user_stats_timeframe( + user_id: UUID, + time_frame: UserStatsTimeFrame, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + usr = UserStatsRepository(db) + return usr.get_user_stats_all_time_frames(user_id=user_id)[time_frame.value] diff --git a/backend/oasst_backend/models/user_stats.py 
b/backend/oasst_backend/models/user_stats.py index 393e3851..5ba9dcdb 100644 --- a/backend/oasst_backend/models/user_stats.py +++ b/backend/oasst_backend/models/user_stats.py @@ -5,7 +5,7 @@ from uuid import UUID import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg -from sqlmodel import Field, SQLModel +from sqlmodel import Field, Index, SQLModel class UserStatsTimeFrame(str, Enum): @@ -17,11 +17,15 @@ class UserStatsTimeFrame(str, Enum): class UserStats(SQLModel, table=True): __tablename__ = "user_stats" + __table_args__ = ( + Index("ix_user_stats__timeframe__user_id", "time_frame", "user_id", unique=True), + Index("ix_user_stats__timeframe__rank__user_id", "time_frame", "rank", "user_id", unique=True), + ) + time_frame: Optional[str] = Field(nullable=False, primary_key=True) user_id: Optional[UUID] = Field( sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True) ) - time_frame: Optional[str] = Field(nullable=False, primary_key=True) base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True)) leader_score: int = 0 @@ -29,6 +33,8 @@ class UserStats(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) ) + rank: int = Field(nullable=True) + prompts: int = 0 replies_assistant: int = 0 replies_prompter: int = 0 diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index 8f047ab3..bdd0e2e9 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -3,6 +3,7 @@ from typing import Optional from uuid import UUID import sqlalchemy as sa +from loguru import logger from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame from oasst_backend.models.db_payload import ( LabelAssistantReplyPayload, @@ -12,14 +13,24 @@ from oasst_backend.models.db_payload import ( from 
oasst_shared.schemas.protocol import LeaderboardStats, UserScore from oasst_shared.utils import log_timing, utcnow from sqlalchemy.dialects import postgresql -from sqlmodel import Session, delete, func +from sqlmodel import Session, delete, func, text + + +def _create_user_score(r): + if r["UserStats"]: + d = r["UserStats"].dict() + else: + d = {"modified_date": utcnow()} + for k in ["user_id", "username", "auth_method", "display_name"]: + d[k] = r[k] + return UserScore(**d) class UserStatsRepository: def __init__(self, session: Session): self.session = session - def get_leader_board(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats: + def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats: """ Get leaderboard stats for the specified time frame """ @@ -32,15 +43,25 @@ class UserStatsRepository: .limit(limit) ) - def create_user_score(user_rank: int, r): - d = r["UserStats"].dict() - for k in ["user_id", "username", "auth_method", "display_name"]: - d[k] = r[k] - return UserScore(user_rank=user_rank, **d) - - leaderboard = [create_user_score(i, r) for i, r in enumerate(self.session.exec(qry))] + leaderboard = [_create_user_score(r) for r in self.session.exec(qry)] return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard) + def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]: + qry = ( + self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats) + .outerjoin(UserStats, User.id == UserStats.user_id) + .filter(User.id == user_id) + ) + + stats_by_timeframe = {} + for r in self.session.exec(qry): + us = r["UserStats"] + if us is not None: + stats_by_timeframe[us.time_frame] = _create_user_score(r) + else: + stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame} + return stats_by_timeframe + def query_total_prompts_per_user( self, reference_time: Optional[datetime] = None, 
only_reviewed: Optional[bool] = True ): @@ -102,9 +123,11 @@ class UserStatsRepository: qry = qry.group_by(Message.user_id) return qry - def _update_stats_internal(self, time_frame_key: str, base_date: Optional[datetime] = None): + def _update_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None): # gather user data + time_frame_key = time_frame.value + stats_by_user: dict[UUID, UserStats] = dict() now = utcnow() @@ -194,8 +217,46 @@ class UserStatsRepository: self.session.add_all(stats_by_user.values()) self.session.flush() - def update_stats_time_frame(self, time_frame_key: str, reference_time: Optional[datetime] = None): - self._update_stats_internal(time_frame_key, reference_time) + self.update_ranks(time_frame=time_frame) + + @log_timing(log_kwargs=True) + def update_ranks(self, time_frame: UserStatsTimeFrame = None): + """ + Update user_stats ranks. The persisted rank values allow us to + quickly look up the rank of a single user and to query nearby users. + """ + + # todo: convert sql to sqlalchemy query.. 
+ # ranks = self.session.query( + # func.row_number() + # .over(partition_by=UserStats.time_frame, order_by=[UserStats.leader_score.desc(), UserStats.user_id]) + # .label("rank"), + # UserStats.user_id, + # UserStats.time_frame, + # ) + + sql_update_rank = """ +-- update rank +UPDATE user_stats us +SET "rank" = r."rank" +FROM + (SELECT + ROW_NUMBER () OVER( + PARTITION BY time_frame + ORDER BY leader_score DESC, user_id + ) AS "rank", user_id, time_frame + FROM user_stats + WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r +WHERE + us.user_id = r.user_id + AND us.time_frame = r.time_frame;""" + r = self.session.execute( + text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None} + ) + logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.") + + def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None): + self._update_stats_internal(time_frame, reference_time) self.session.commit() @log_timing(log_kwargs=True, level="INFO") @@ -204,20 +265,20 @@ class UserStatsRepository: match time_frame: case UserStatsTimeFrame.day: r = now - timedelta(days=1) - self.update_stats_time_frame(time_frame.value, r) + self.update_stats_time_frame(time_frame, r) case UserStatsTimeFrame.week: r = now.date() - timedelta(days=7) r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo) - self.update_stats_time_frame(time_frame.value, r) + self.update_stats_time_frame(time_frame, r) case UserStatsTimeFrame.month: r = now.date() - timedelta(days=30) r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo) - self.update_stats_time_frame(time_frame.value, r) + self.update_stats_time_frame(time_frame, r) case UserStatsTimeFrame.total: - self.update_stats_time_frame(time_frame.value, None) + self.update_stats_time_frame(time_frame, None) @log_timing(level="INFO") def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]): @@ -236,5 +297,7 @@ if __name__ == 
"__main__": with Session(engine) as session: api_client = get_dummy_api_client(session) usr = UserStatsRepository(session) - usr.update_all_time_frames() - usr.get_leader_board(UserStatsTimeFrame.total) + # usr.update_all_time_frames() + # session.commit() + # usr.get_leader_board(UserStatsTimeFrame.total) + usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce")) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 23b19e71..006a6026 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -358,32 +358,32 @@ class SystemStats(BaseModel): class UserScore(BaseModel): - user_rank: int + rank: Optional[int] user_id: UUID username: str auth_method: str display_name: str - leader_score: int + leader_score: int = 0 base_date: Optional[datetime] - modified_date: datetime + modified_date: Optional[datetime] - prompts: int - replies_assistant: int - replies_prompter: int - labels_simple: int - labels_full: int - rankings_total: int - rankings_good: int + prompts: int = 0 + replies_assistant: int = 0 + replies_prompter: int = 0 + labels_simple: int = 0 + labels_full: int = 0 + rankings_total: int = 0 + rankings_good: int = 0 - accepted_prompts: int - accepted_replies_assistant: int - accepted_replies_prompter: int + accepted_prompts: int = 0 + accepted_replies_assistant: int = 0 + accepted_replies_prompter: int = 0 - reply_ranked_1: int - reply_ranked_2: int - reply_ranked_3: int + reply_ranked_1: int = 0 + reply_ranked_2: int = 0 + reply_ranked_3: int = 0 # only used for time frame "total" streak_last_day_date: Optional[datetime]