diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py index 27366475..6c21bb9e 100644 --- a/backend/oasst_backend/api/v1/leaderboards.py +++ b/backend/oasst_backend/api/v1/leaderboards.py @@ -1,8 +1,10 @@ from typing import Optional +from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.models import ApiClient +from oasst_backend.user_repository import UserRepository from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame from oasst_shared.schemas.protocol import LeaderboardStats from sqlmodel import Session @@ -15,11 +17,17 @@ router = APIRouter() def get_leaderboard( time_frame: UserStatsTimeFrame, max_count: Optional[int] = Query(100, gt=0, le=10000), + frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id), api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ) -> LeaderboardStats: + current_user_id: UUID | None = None + if frontend_user.username: + ur = UserRepository(db, api_client) + current_user = ur.query_frontend_user(auth_method=frontend_user.auth_method, username=frontend_user.username) + current_user_id = current_user.id usr = UserStatsRepository(db) - return usr.get_leaderboard(time_frame, limit=max_count) + return usr.get_leaderboard(time_frame, limit=max_count, highlighted_user_id=current_user_id) @router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index dc0a3242..2ced40c1 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -308,3 +308,17 @@ def query_user_stats_timeframe( ): usr = UserStatsRepository(db) return usr.get_user_stats_all_time_frames(user_id=user_id)[time_frame.value] + + +@router.get("/{user_id}/stats/{time_frame}/window", response_model=protocol.LeaderboardStats | None) +def query_user_stats_timeframe_window( + user_id: UUID, + time_frame: UserStatsTimeFrame, + window_size: Optional[int] = Query(5, gt=0, le=100), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +) -> protocol.LeaderboardStats | None: + ur = UserRepository(db, api_client=api_client) + user = ur.get_user(id=user_id) + usr = UserStatsRepository(db) + return usr.get_leaderboard_user_window(user=user, time_frame=time_frame, window_size=window_size) diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index 49a4d7b1..3862d098 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -17,7 +17,7 @@ from sqlalchemy.dialects import postgresql from sqlmodel import Session, delete, func, text -def _create_user_score(r): +def _create_user_score(r, highlighted_user_id: UUID | None): if r["UserStats"]: d = r["UserStats"].dict() else: @@ -32,6 +32,8 @@ def _create_user_score(r): "last_activity_date", ]: d[k] = r[k] + if highlighted_user_id: + d["highlighted"] = r["user_id"] == highlighted_user_id return UserScore(**d) @@ -39,7 +41,12 @@ class UserStatsRepository: def __init__(self, session: Session): self.session = session - def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats: + def get_leaderboard( + self, + time_frame: UserStatsTimeFrame, + limit: int = 100, + highlighted_user_id: Optional[UUID] = None, + ) -> LeaderboardStats: """ Get leaderboard stats for the specified time frame """ @@ -61,7 +68,49 @@ class UserStatsRepository: .limit(limit) ) - leaderboard = [_create_user_score(r) for r in self.session.exec(qry)] + leaderboard = [_create_user_score(r, highlighted_user_id) for r in self.session.exec(qry)] + if len(leaderboard) > 0: + last_update = max(x.modified_date for x in leaderboard) + else: + last_update = utcnow() + return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard, last_updated=last_update) + + def get_leaderboard_user_window( + self, + user: User, + time_frame: UserStatsTimeFrame, + window_size: int = 5, + ) -> LeaderboardStats | None: + # no window for users who don't show themselves + if not user.show_on_leaderboard: + return None + + qry = self.session.query(UserStats).filter(UserStats.user_id == user.id, UserStats.time_frame == time_frame) + stats: UserStats = qry.one_or_none() + if stats is None or stats.rank is None: + return None + + min_rank = max(0, stats.rank - window_size // 2) + max_rank = min_rank + window_size + + qry = ( + self.session.query( + User.id.label("user_id"), + User.username, + User.auth_method, + User.display_name, + User.streak_days, + User.streak_last_day_date, + User.last_activity_date, + UserStats, + ) + .join(UserStats, User.id == UserStats.user_id) + .filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard) + .where(UserStats.rank >= min_rank, UserStats.rank <= max_rank) + .order_by(UserStats.rank) + ) + + leaderboard = [_create_user_score(r, highlighted_user_id=user.id) for r in self.session.exec(qry)] if len(leaderboard) > 0: last_update = max(x.modified_date for x in leaderboard) else: @@ -79,9 +128,9 @@ class UserStatsRepository: for r in self.session.exec(qry): us = r["UserStats"] if us is not None: - stats_by_timeframe[us.time_frame] = _create_user_score(r) + stats_by_timeframe[us.time_frame] = _create_user_score(r, user_id) else: - stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame} + stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame} return stats_by_timeframe def query_total_prompts_per_user( diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index c1932b27..f772eb81 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -26,7 +26,7 @@ class TaskRequestType(str, enum.Enum): class User(BaseModel): id: str display_name: str - auth_method: Literal["discord", "local"] + auth_method: Literal["discord", "local", "system"] class FrontEndUser(User): @@ -432,6 +432,7 @@ class SystemStats(BaseModel): class UserScore(BaseModel): rank: Optional[int] user_id: UUID + highlighted: bool = False username: str auth_method: str display_name: str