mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
537: Endpoint to list frontend users (#554)
* added frontend users endpoint * fix comparison * added api_client_id filtration * allow untrusted api-clients * review fixes Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
This commit is contained in:
@@ -15,6 +15,21 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[protocol.User])
|
||||
def get_users(
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=20), # TODO: refine bounds
|
||||
gte: str = None,
|
||||
lt: str = None,
|
||||
auth_method: str = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
pr = UserRepository(db, api_client)
|
||||
users = pr.query_users(api_client_id=api_client_id, limit=max_count, gte=gte, lt=lt, auth_method=auth_method)
|
||||
return [u.to_protocol_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}", response_model=protocol.FrontEndUser)
|
||||
def query_frontend_user(
|
||||
auth_method: str,
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import AutoString, Field, Index, SQLModel
|
||||
|
||||
|
||||
@@ -26,3 +27,6 @@ class User(SQLModel, table=True):
|
||||
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
notes: str = Field(sa_column=sa.Column(AutoString(length=1024), nullable=False, server_default="''"))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
|
||||
|
||||
def to_protocol_user(self):
|
||||
return protocol.User(id=self.username, display_name=self.display_name, auth_method=self.auth_method)
|
||||
|
||||
@@ -157,3 +157,39 @@ class UserRepository:
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
|
||||
def query_users(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = 20,
|
||||
gte: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if api_client_id != self.api_client.id:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
users = self.db.query(User)
|
||||
|
||||
if api_client_id:
|
||||
users = users.filter(User.api_client_id == api_client_id)
|
||||
|
||||
if auth_method:
|
||||
users = users.filter(User.auth_method == auth_method)
|
||||
|
||||
users = users.order_by(User.display_name)
|
||||
|
||||
if gte:
|
||||
users = users.filter(User.display_name >= gte)
|
||||
|
||||
if lt:
|
||||
users = users.filter(User.display_name < lt)
|
||||
|
||||
if limit is not None:
|
||||
users = users.limit(limit)
|
||||
|
||||
return users.all()
|
||||
|
||||
Reference in New Issue
Block a user