mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
Add endpoint to query nearby leaderboard rows (#1038)
* add is_current_user bool * add user leaderboard surrounding window function
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
@@ -15,11 +17,17 @@ router = APIRouter()
|
||||
def get_leaderboard(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> LeaderboardStats:
|
||||
current_user_id: UUID | None = None
|
||||
if frontend_user.username:
|
||||
ur = UserRepository(db, api_client)
|
||||
current_user = ur.query_frontend_user(auth_method=frontend_user.auth_method, username=frontend_user.username)
|
||||
current_user_id = current_user.id
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leaderboard(time_frame, limit=max_count)
|
||||
return usr.get_leaderboard(time_frame, limit=max_count, highlighted_user_id=current_user_id)
|
||||
|
||||
|
||||
@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -308,3 +308,17 @@ def query_user_stats_timeframe(
|
||||
):
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_user_stats_all_time_frames(user_id=user_id)[time_frame.value]
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats/{time_frame}/window", response_model=protocol.LeaderboardStats | None)
|
||||
def query_user_stats_timeframe_window(
|
||||
user_id: UUID,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
window_size: Optional[int] = Query(5, gt=0, le=100),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> protocol.LeaderboardStats | None:
|
||||
ur = UserRepository(db, api_client=api_client)
|
||||
user = ur.get_user(id=user_id)
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leaderboard_user_window(user=user, time_frame=time_frame, window_size=window_size)
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlalchemy.dialects import postgresql
|
||||
from sqlmodel import Session, delete, func, text
|
||||
|
||||
|
||||
def _create_user_score(r):
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
if r["UserStats"]:
|
||||
d = r["UserStats"].dict()
|
||||
else:
|
||||
@@ -32,6 +32,8 @@ def _create_user_score(r):
|
||||
"last_activity_date",
|
||||
]:
|
||||
d[k] = r[k]
|
||||
if highlighted_user_id:
|
||||
d["highlighted"] = r["user_id"] == highlighted_user_id
|
||||
return UserScore(**d)
|
||||
|
||||
|
||||
@@ -39,7 +41,12 @@ class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_leaderboard(self, time_frame: UserStatsTimeFrame, limit: int = 100) -> LeaderboardStats:
|
||||
def get_leaderboard(
|
||||
self,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
limit: int = 100,
|
||||
highlighted_user_id: Optional[UUID] = None,
|
||||
) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for the specified time frame
|
||||
"""
|
||||
@@ -61,7 +68,49 @@ class UserStatsRepository:
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
|
||||
leaderboard = [_create_user_score(r, highlighted_user_id) for r in self.session.exec(qry)]
|
||||
if len(leaderboard) > 0:
|
||||
last_update = max(x.modified_date for x in leaderboard)
|
||||
else:
|
||||
last_update = utcnow()
|
||||
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard, last_updated=last_update)
|
||||
|
||||
def get_leaderboard_user_window(
|
||||
self,
|
||||
user: User,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
window_size: int = 5,
|
||||
) -> LeaderboardStats | None:
|
||||
# no window for users who don't show themselves
|
||||
if not user.show_on_leaderboard:
|
||||
return None
|
||||
|
||||
qry = self.session.query(UserStats).filter(UserStats.user_id == user.id, UserStats.time_frame == time_frame)
|
||||
stats: UserStats = qry.one_or_none()
|
||||
if stats is None or stats.rank is None:
|
||||
return None
|
||||
|
||||
min_rank = max(0, stats.rank - window_size // 2)
|
||||
max_rank = min_rank + window_size
|
||||
|
||||
qry = (
|
||||
self.session.query(
|
||||
User.id.label("user_id"),
|
||||
User.username,
|
||||
User.auth_method,
|
||||
User.display_name,
|
||||
User.streak_days,
|
||||
User.streak_last_day_date,
|
||||
User.last_activity_date,
|
||||
UserStats,
|
||||
)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
|
||||
.where(UserStats.rank >= min_rank, UserStats.rank <= max_rank)
|
||||
.order_by(UserStats.rank)
|
||||
)
|
||||
|
||||
leaderboard = [_create_user_score(r, highlighted_user_id=user.id) for r in self.session.exec(qry)]
|
||||
if len(leaderboard) > 0:
|
||||
last_update = max(x.modified_date for x in leaderboard)
|
||||
else:
|
||||
@@ -79,9 +128,9 @@ class UserStatsRepository:
|
||||
for r in self.session.exec(qry):
|
||||
us = r["UserStats"]
|
||||
if us is not None:
|
||||
stats_by_timeframe[us.time_frame] = _create_user_score(r)
|
||||
stats_by_timeframe[us.time_frame] = _create_user_score(r, user_id)
|
||||
else:
|
||||
stats_by_timeframe = {tf.value: _create_user_score(r) for tf in UserStatsTimeFrame}
|
||||
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
|
||||
return stats_by_timeframe
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
|
||||
@@ -26,7 +26,7 @@ class TaskRequestType(str, enum.Enum):
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
display_name: str
|
||||
auth_method: Literal["discord", "local"]
|
||||
auth_method: Literal["discord", "local", "system"]
|
||||
|
||||
|
||||
class FrontEndUser(User):
|
||||
@@ -432,6 +432,7 @@ class SystemStats(BaseModel):
|
||||
class UserScore(BaseModel):
|
||||
rank: Optional[int]
|
||||
user_id: UUID
|
||||
highlighted: bool = False
|
||||
username: str
|
||||
auth_method: str
|
||||
display_name: str
|
||||
|
||||
Reference in New Issue
Block a user