Add endpoint to query nearby leaderboard rows (#1038)

* add is_current_user bool

* add user leaderboard surrounding window function
This commit is contained in:
Andreas Köpf
2023-01-31 15:05:05 +01:00
committed by GitHub
parent 355d621488
commit b6bdb84019
4 changed files with 79 additions and 7 deletions
+9 -1
View File
@@ -1,8 +1,10 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.user_repository import UserRepository
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
@@ -15,11 +17,17 @@ router = APIRouter()
def get_leaderboard(
time_frame: UserStatsTimeFrame,
max_count: Optional[int] = Query(100, gt=0, le=10000),
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
) -> LeaderboardStats:
current_user_id: UUID | None = None
if frontend_user.username:
ur = UserRepository(db, api_client)
current_user = ur.query_frontend_user(auth_method=frontend_user.auth_method, username=frontend_user.username)
current_user_id = current_user.id
usr = UserStatsRepository(db)
return usr.get_leaderboard(time_frame, limit=max_count)
return usr.get_leaderboard(time_frame, limit=max_count, highlighted_user_id=current_user_id)
@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT)
+14
View File
@@ -308,3 +308,17 @@ def query_user_stats_timeframe(
):
usr = UserStatsRepository(db)
return usr.get_user_stats_all_time_frames(user_id=user_id)[time_frame.value]
@router.get("/{user_id}/stats/{time_frame}/window", response_model=protocol.LeaderboardStats | None)
def query_user_stats_timeframe_window(
user_id: UUID,
time_frame: UserStatsTimeFrame,
window_size: Optional[int] = Query(5, gt=0, le=100),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
) -> protocol.LeaderboardStats | None:
ur = UserRepository(db, api_client=api_client)
user = ur.get_user(id=user_id)
usr = UserStatsRepository(db)
return usr.get_leaderboard_user_window(user=user, time_frame=time_frame, window_size=window_size)
+54 -5
View File
@@ -17,7 +17,7 @@ from sqlalchemy.dialects import postgresql
from sqlmodel import Session, delete, func, text
def _create_user_score(r):
def _create_user_score(r, highlighted_user_id: UUID | None):
if r["UserStats"]:
d = r["UserStats"].dict()
else:
@@ -32,6 +32,8 @@ def _create_user_score(r):
"last_activity_date",
]:
d[k] = r[k]
if highlighted_user_id:
d["highlighted"] = r["user_id"] == highlighted_user_id
return UserScore(**d)
@@ -39,7 +41,12 @@ class UserStatsRepository:
def __init__(self, session: Session):
self.session = session
def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
def get_leaderboard(
self,
time_frame: UserStatsTimeFrame,
limit: int = 100,
highlighted_user_id: Optional[UUID] = None,
) -> LeaderboardStats:
"""
Get leaderboard stats for the specified time frame
"""
@@ -61,7 +68,49 @@ class UserStatsRepository:
.limit(limit)
)
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
leaderboard = [_create_user_score(r, highlighted_user_id) for r in self.session.exec(qry)]
if len(leaderboard) > 0:
last_update = max(x.modified_date for x in leaderboard)
else:
last_update = utcnow()
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard, last_updated=last_update)
def get_leaderboard_user_window(
self,
user: User,
time_frame: UserStatsTimeFrame,
window_size: int = 5,
) -> LeaderboardStats | None:
# no window for users who don't show themselves
if not user.show_on_leaderboard:
return None
qry = self.session.query(UserStats).filter(UserStats.user_id == user.id, UserStats.time_frame == time_frame)
stats: UserStats = qry.one_or_none()
if stats is None or stats.rank is None:
return None
min_rank = max(0, stats.rank - window_size // 2)
max_rank = min_rank + window_size
qry = (
self.session.query(
User.id.label("user_id"),
User.username,
User.auth_method,
User.display_name,
User.streak_days,
User.streak_last_day_date,
User.last_activity_date,
UserStats,
)
.join(UserStats, User.id == UserStats.user_id)
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
.where(UserStats.rank >= min_rank, UserStats.rank <= max_rank)
.order_by(UserStats.rank)
)
leaderboard = [_create_user_score(r, highlighted_user_id=user.id) for r in self.session.exec(qry)]
if len(leaderboard) > 0:
last_update = max(x.modified_date for x in leaderboard)
else:
@@ -79,9 +128,9 @@ class UserStatsRepository:
for r in self.session.exec(qry):
us = r["UserStats"]
if us is not None:
stats_by_timeframe[us.time_frame] = _create_user_score(r)
stats_by_timeframe[us.time_frame] = _create_user_score(r, user_id)
else:
stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame}
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
return stats_by_timeframe
def query_total_prompts_per_user(
@@ -26,7 +26,7 @@ class TaskRequestType(str, enum.Enum):
class User(BaseModel):
id: str
display_name: str
auth_method: Literal["discord", "local"]
auth_method: Literal["discord", "local", "system"]
class FrontEndUser(User):
@@ -432,6 +432,7 @@ class SystemStats(BaseModel):
class UserScore(BaseModel):
rank: Optional[int]
user_id: UUID
highlighted: bool = False
username: str
auth_method: str
display_name: str