diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 1caf1f46..8865ea0c 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -15,6 +15,21 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() +@router.get("/", response_model=list[protocol.User]) +def get_users( + api_client_id: UUID = None, + max_count: int = Query(10, gt=0, le=20), # TODO: refine bounds + gte: str = None, + lt: str = None, + auth_method: str = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + pr = UserRepository(db, api_client) + users = pr.query_users(api_client_id=api_client_id, limit=max_count, gte=gte, lt=lt, auth_method=auth_method) + return [u.to_protocol_user() for u in users] + + @router.get("/{auth_method}/{username}", response_model=protocol.FrontEndUser) def query_frontend_user( auth_method: str, diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 98cca465..8c49eaee 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -4,6 +4,7 @@ from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg +from oasst_shared.schemas import protocol from sqlmodel import AutoString, Field, Index, SQLModel @@ -26,3 +27,6 @@ class User(SQLModel, table=True): enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true())) notes: str = Field(sa_column=sa.Column(AutoString(length=1024), nullable=False, server_default="''")) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false())) + + def to_protocol_user(self): + return protocol.User(id=self.username, display_name=self.display_name, auth_method=self.auth_method) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 14a15b35..63967d17 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -157,3 +157,39 @@ class UserRepository: ] return LeaderboardStats(leaderboard=result) + + def query_users( + self, + api_client_id: Optional[UUID] = None, + limit: Optional[int] = 20, + gte: Optional[str] = None, + lt: Optional[str] = None, + auth_method: Optional[str] = None, + ) -> list[User]: + if not self.api_client.trusted: + if not api_client_id: + api_client_id = self.api_client.id + + if api_client_id != self.api_client.id: + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + + users = self.db.query(User) + + if api_client_id: + users = users.filter(User.api_client_id == api_client_id) + + if auth_method: + users = users.filter(User.auth_method == auth_method) + + users = users.order_by(User.display_name) + + if gte: + users = users.filter(User.display_name >= gte) + + if lt: + users = users.filter(User.display_name < lt) + + if limit is not None: + users = users.limit(limit) + + return users.all()