add /api/v1/users/{user_id}/stats endpoint (#744)

* add /api/v1/users/{user_id}/stats endpoint

* return 0 stats and add /api/v1/users/{user_id}/stats/{time_frame}

* use utcnow() as modified date for 0 stats
This commit is contained in:
Andreas Köpf
2023-01-15 21:24:15 +01:00
committed by GitHub
parent ed80762182
commit e58ffd64fa
7 changed files with 165 additions and 37 deletions
@@ -0,0 +1,33 @@
"""add rank and indices to user_stats
Revision ID: 0964ac95170d
Revises: 423557e869e4
Create Date: 2023-01-15 16:54:09.510018
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "0964ac95170d"
down_revision = "423557e869e4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("user_stats", sa.Column("rank", sa.Integer(), nullable=True))
op.create_index(
"ix_user_stats__timeframe__rank__user_id", "user_stats", ["time_frame", "rank", "user_id"], unique=True
)
op.create_index("ix_user_stats__timeframe__user_id", "user_stats", ["time_frame", "user_id"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_user_stats__timeframe__user_id", table_name="user_stats")
op.drop_index("ix_user_stats__timeframe__rank__user_id", table_name="user_stats")
op.drop_column("user_stats", "rank")
# ### end Alembic commands ###
+4
View File
@@ -204,6 +204,7 @@ def update_leader_board_day() -> None:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.day)
session.commit()
except Exception:
logger.exception("Error during leaderboard update (daily)")
@@ -215,6 +216,7 @@ def update_leader_board_week() -> None:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.week)
session.commit()
except Exception:
logger.exception("Error during user states update (weekly)")
@@ -226,6 +228,7 @@ def update_leader_board_month() -> None:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.month)
session.commit()
except Exception:
logger.exception("Error during user states update (monthly)")
@@ -237,6 +240,7 @@ def update_leader_board_total() -> None:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.total)
session.commit()
except Exception:
logger.exception("Error during user states update (total)")
+1 -1
View File
@@ -18,4 +18,4 @@ def get_leaderboard(
db: Session = Depends(deps.get_db),
) -> LeaderboardStats:
usr = UserStatsRepository(db)
return usr.get_leader_board(time_frame, limit=max_count)
return usr.get_leaderboard(time_frame, limit=max_count)
+22
View File
@@ -8,6 +8,7 @@ from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient, User
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -96,3 +97,24 @@ def mark_user_messages_deleted(
pr = PromptRepository(db, api_client)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
@router.get("/{user_id}/stats", response_model=dict[str, protocol.UserScore | None])
def query_user_stats(
user_id: UUID,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
usr = UserStatsRepository(db)
return usr.get_user_stats_all_time_frames(user_id=user_id)
@router.get("/{user_id}/stats/{time_frame}", response_model=protocol.UserScore)
def query_user_stats_timeframe(
user_id: UUID,
time_frame: UserStatsTimeFrame,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
usr = UserStatsRepository(db)
return usr.get_user_stats_all_time_frames(user_id=user_id)[time_frame.value]
+8 -2
View File
@@ -5,7 +5,7 @@ from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
from sqlmodel import Field, Index, SQLModel
class UserStatsTimeFrame(str, Enum):
@@ -17,11 +17,15 @@ class UserStatsTimeFrame(str, Enum):
class UserStats(SQLModel, table=True):
__tablename__ = "user_stats"
__table_args__ = (
Index("ix_user_stats__timeframe__user_id", "time_frame", "user_id", unique=True),
Index("ix_user_stats__timeframe__rank__user_id", "time_frame", "rank", "user_id", unique=True),
)
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
user_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
)
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
leader_score: int = 0
@@ -29,6 +33,8 @@ class UserStats(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
rank: int = Field(nullable=True)
prompts: int = 0
replies_assistant: int = 0
replies_prompter: int = 0
+81 -18
View File
@@ -3,6 +3,7 @@ from typing import Optional
from uuid import UUID
import sqlalchemy as sa
from loguru import logger
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
@@ -12,14 +13,24 @@ from oasst_backend.models.db_payload import (
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy.dialects import postgresql
from sqlmodel import Session, delete, func
from sqlmodel import Session, delete, func, text
def _create_user_score(r):
if r["UserStats"]:
d = r["UserStats"].dict()
else:
d = {"modified_date": utcnow()}
for k in ["user_id", "username", "auth_method", "display_name"]:
d[k] = r[k]
return UserScore(**d)
class UserStatsRepository:
def __init__(self, session: Session):
self.session = session
def get_leader_board(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
"""
Get leaderboard stats for the specified time frame
"""
@@ -32,15 +43,25 @@ class UserStatsRepository:
.limit(limit)
)
def create_user_score(user_rank: int, r):
d = r["UserStats"].dict()
for k in ["user_id", "username", "auth_method", "display_name"]:
d[k] = r[k]
return UserScore(user_rank=user_rank, **d)
leaderboard = [create_user_score(i, r) for i, r in enumerate(self.session.exec(qry))]
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
qry = (
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
.outerjoin(UserStats, User.id == UserStats.user_id)
.filter(User.id == user_id)
)
stats_by_timeframe = {}
for r in self.session.exec(qry):
us = r["UserStats"]
if us is not None:
stats_by_timeframe[us.time_frame] = _create_user_score(r)
else:
stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame}
return stats_by_timeframe
def query_total_prompts_per_user(
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
):
@@ -102,9 +123,11 @@ class UserStatsRepository:
qry = qry.group_by(Message.user_id)
return qry
def _update_stats_internal(self, time_frame_key: str, base_date: Optional[datetime] = None):
def _update_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
# gather user data
time_frame_key = time_frame.value
stats_by_user: dict[UUID, UserStats] = dict()
now = utcnow()
@@ -194,8 +217,46 @@ class UserStatsRepository:
self.session.add_all(stats_by_user.values())
self.session.flush()
def update_stats_time_frame(self, time_frame_key: str, reference_time: Optional[datetime] = None):
self._update_stats_internal(time_frame_key, reference_time)
self.update_ranks(time_frame=time_frame)
@log_timing(log_kwargs=True)
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
"""
Update user_stats ranks. The persisted rank values allow to
quickly the rank of a single user and to query nearby users.
"""
# todo: convert sql to sqlalchemy query..
# ranks = self.session.query(
# func.row_number()
# .over(partition_by=UserStats.time_frame, order_by=[UserStats.leader_score.desc(), UserStats.user_id])
# .label("rank"),
# UserStats.user_id,
# UserStats.time_frame,
# )
sql_update_rank = """
-- update rank
UPDATE user_stats us
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY leader_score DESC, user_id
) AS "rank", user_id, time_frame
FROM user_stats
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
us.user_id = r.user_id
AND us.time_frame = r.time_frame;"""
r = self.session.execute(
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
)
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
self._update_stats_internal(time_frame, reference_time)
self.session.commit()
@log_timing(log_kwargs=True, level="INFO")
@@ -204,20 +265,20 @@ class UserStatsRepository:
match time_frame:
case UserStatsTimeFrame.day:
r = now - timedelta(days=1)
self.update_stats_time_frame(time_frame.value, r)
self.update_stats_time_frame(time_frame, r)
case UserStatsTimeFrame.week:
r = now.date() - timedelta(days=7)
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
self.update_stats_time_frame(time_frame.value, r)
self.update_stats_time_frame(time_frame, r)
case UserStatsTimeFrame.month:
r = now.date() - timedelta(days=30)
r = datetime(r.year, r.month, r.day, tzinfo=now.tzinfo)
self.update_stats_time_frame(time_frame.value, r)
self.update_stats_time_frame(time_frame, r)
case UserStatsTimeFrame.total:
self.update_stats_time_frame(time_frame.value, None)
self.update_stats_time_frame(time_frame, None)
@log_timing(level="INFO")
def update_multiple_time_frames(self, time_frames: list[UserStatsTimeFrame]):
@@ -236,5 +297,7 @@ if __name__ == "__main__":
with Session(engine) as session:
api_client = get_dummy_api_client(session)
usr = UserStatsRepository(session)
usr.update_all_time_frames()
usr.get_leader_board(UserStatsTimeFrame.total)
# usr.update_all_time_frames()
# session.commit()
# usr.get_leader_board(UserStatsTimeFrame.total)
usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
+16 -16
View File
@@ -358,32 +358,32 @@ class SystemStats(BaseModel):
class UserScore(BaseModel):
user_rank: int
rank: Optional[int]
user_id: UUID
username: str
auth_method: str
display_name: str
leader_score: int
leader_score: int = 0
base_date: Optional[datetime]
modified_date: datetime
modified_date: Optional[datetime]
prompts: int
replies_assistant: int
replies_prompter: int
labels_simple: int
labels_full: int
rankings_total: int
rankings_good: int
prompts: int = 0
replies_assistant: int = 0
replies_prompter: int = 0
labels_simple: int = 0
labels_full: int = 0
rankings_total: int = 0
rankings_good: int = 0
accepted_prompts: int
accepted_replies_assistant: int
accepted_replies_prompter: int
accepted_prompts: int = 0
accepted_replies_assistant: int = 0
accepted_replies_prompter: int = 0
reply_ranked_1: int
reply_ranked_2: int
reply_ranked_3: int
reply_ranked_1: int = 0
reply_ranked_2: int = 0
reply_ranked_3: int = 0
# only used for time frame "total"
streak_last_day_date: Optional[datetime]