537: Endpoint to list frontend users (#554)

* added frontend users endpoint

* fix comparison

* added api_client_id filtration

* allow untrusted api-clients

* review fixes

Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
This commit is contained in:
Vechtomov
2023-01-13 00:45:11 +03:00
committed by GitHub
parent f264b43cde
commit 0d646e72f3
3 changed files with 55 additions and 0 deletions
@@ -15,6 +15,21 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/", response_model=list[protocol.User])
def get_users(
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=20), # TODO: refine bounds
gte: str = None,
lt: str = None,
auth_method: str = None,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
pr = UserRepository(db, api_client)
users = pr.query_users(api_client_id=api_client_id, limit=max_count, gte=gte, lt=lt, auth_method=auth_method)
return [u.to_protocol_user() for u in users]
@router.get("/{auth_method}/{username}", response_model=protocol.FrontEndUser)
def query_frontend_user(
auth_method: str,
+4
View File
@@ -4,6 +4,7 @@ from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_shared.schemas import protocol
from sqlmodel import AutoString, Field, Index, SQLModel
@@ -26,3 +27,6 @@ class User(SQLModel, table=True):
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
notes: str = Field(sa_column=sa.Column(AutoString(length=1024), nullable=False, server_default="''"))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
def to_protocol_user(self):
return protocol.User(id=self.username, display_name=self.display_name, auth_method=self.auth_method)
+36
View File
@@ -157,3 +157,39 @@ class UserRepository:
]
return LeaderboardStats(leaderboard=result)
def query_users(
self,
api_client_id: Optional[UUID] = None,
limit: Optional[int] = 20,
gte: Optional[str] = None,
lt: Optional[str] = None,
auth_method: Optional[str] = None,
) -> list[User]:
if not self.api_client.trusted:
if not api_client_id:
api_client_id = self.api_client.id
if api_client_id != self.api_client.id:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
users = self.db.query(User)
if api_client_id:
users = users.filter(User.api_client_id == api_client_id)
if auth_method:
users = users.filter(User.auth_method == auth_method)
users = users.order_by(User.display_name)
if gte:
users = users.filter(User.display_name >= gte)
if lt:
users = users.filter(User.display_name < lt)
if limit is not None:
users = users.limit(limit)
return users.all()