mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add keyset pagination for users ordered by username / display_name (#851)
* add keyset pagination for user ordered by username or display_name * add index on display-name for user table * update down_revision in migration script
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""add ix_user_display_name_id
|
||||
|
||||
Revision ID: 4f26fec4d204
|
||||
Revises: 0964ac95170d
|
||||
Create Date: 2023-01-19 22:00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4f26fec4d204"
|
||||
down_revision = "7f0a28a156f4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_user_display_name_id", table_name="user")
|
||||
# ### end Alembic commands ###
|
||||
@@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[protocol.FrontEndUser])
|
||||
def get_users(
|
||||
@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True)
|
||||
def get_users_ordered_by_username(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
gt: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
gte_username: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_username: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
search_text: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/by_display_name")
|
||||
def query_frontend_users_by_display_name(
|
||||
search_text: str,
|
||||
exact: bool = False,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(20, gt=0, le=1000),
|
||||
auth_method: str = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users_by_display_name(
|
||||
search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method
|
||||
users = ur.query_users_ordered_by_username(
|
||||
api_client_id=api_client_id,
|
||||
gte_username=gte_username,
|
||||
gt_id=gt_id,
|
||||
lte_username=lte_username,
|
||||
lt_id=lt_id,
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=protocol.FrontEndUser)
|
||||
@router.get("/by_username", response_model=list[protocol.FrontEndUser])
|
||||
def get_users_ordered_by_username(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
gte_username: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_username: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
search_text: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users_ordered_by_username(
|
||||
api_client_id=api_client_id,
|
||||
gte_username=gte_username,
|
||||
gt_id=gt_id,
|
||||
lte_username=lte_username,
|
||||
lt_id=lt_id,
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/by_display_name", response_model=list[protocol.FrontEndUser])
|
||||
def get_users_ordered_by_display_name(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
gte_display_name: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_display_name: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users_ordered_by_display_name(
|
||||
api_client_id=api_client_id,
|
||||
gte_display_name=gte_display_name,
|
||||
gt_id=gt_id,
|
||||
lte_display_name=lte_display_name,
|
||||
lt_id=lt_id,
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=protocol.FrontEndUser)
|
||||
def get_user(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
@@ -31,7 +85,7 @@ def get_user(
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def update_user(
|
||||
user_id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
@@ -46,7 +100,7 @@ def update_user(
|
||||
ur.update_user(user_id, enabled, notes)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "user"
|
||||
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
__table_args__ = (
|
||||
Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),
|
||||
Index("ix_user_display_name_id", "display_name", "id", unique=True),
|
||||
)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
|
||||
@@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, and_, or_
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -135,13 +135,16 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
return user
|
||||
|
||||
def query_users(
|
||||
def query_users_ordered_by_username(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = 20,
|
||||
gt: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
gte_username: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_username: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
@@ -150,34 +153,52 @@ class UserRepository:
|
||||
if api_client_id != self.api_client.id:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
users = self.db.query(User)
|
||||
qry = self.db.query(User).order_by(User.username, User.id)
|
||||
|
||||
if api_client_id:
|
||||
users = users.filter(User.api_client_id == api_client_id)
|
||||
if gte_username is not None:
|
||||
if gt_id:
|
||||
qry = qry.filter(
|
||||
or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id))
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.username >= gte_username)
|
||||
elif gt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if lte_username is not None:
|
||||
if lt_id:
|
||||
qry = qry.filter(
|
||||
or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id))
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.username <= lte_username)
|
||||
elif lt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if auth_method:
|
||||
users = users.filter(User.auth_method == auth_method)
|
||||
qry = qry.filter(User.auth_method == auth_method)
|
||||
if api_client_id:
|
||||
qry = qry.filter(User.api_client_id == api_client_id)
|
||||
|
||||
users = users.order_by(User.display_name)
|
||||
|
||||
if gt:
|
||||
users = users.filter(User.display_name > gt)
|
||||
|
||||
if lt:
|
||||
users = users.filter(User.display_name < lt)
|
||||
if search_text:
|
||||
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
|
||||
qry = qry.filter(User.username.like(pattern))
|
||||
|
||||
if limit is not None:
|
||||
users = users.limit(limit)
|
||||
qry = qry.limit(limit)
|
||||
|
||||
return users.all()
|
||||
return qry.all()
|
||||
|
||||
def query_users_by_display_name(
|
||||
def query_users_ordered_by_display_name(
|
||||
self,
|
||||
search_text: str,
|
||||
exact: Optional[bool] = False,
|
||||
limit: Optional[int] = 20,
|
||||
gte_display_name: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_display_name: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
@@ -186,11 +207,40 @@ class UserRepository:
|
||||
if api_client_id != self.api_client.id:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
qry = self.db.query(User).order_by(User.display_name)
|
||||
qry = self.db.query(User).order_by(User.display_name, User.id)
|
||||
|
||||
if exact:
|
||||
qry = qry.filter(User.display_name == search_text)
|
||||
else:
|
||||
if gte_display_name is not None:
|
||||
if gt_id:
|
||||
qry = qry.filter(
|
||||
or_(
|
||||
User.display_name > gte_display_name,
|
||||
and_(User.display_name == gte_display_name, User.id > gt_id),
|
||||
)
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.display_name >= gte_display_name)
|
||||
elif gt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if lte_display_name is not None:
|
||||
if lt_id:
|
||||
qry = qry.filter(
|
||||
or_(
|
||||
User.display_name < lte_display_name,
|
||||
and_(User.display_name == lte_display_name, User.id < lt_id),
|
||||
)
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.display_name <= lte_display_name)
|
||||
elif lt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if auth_method:
|
||||
qry = qry.filter(User.auth_method == auth_method)
|
||||
if api_client_id:
|
||||
qry = qry.filter(User.api_client_id == api_client_id)
|
||||
|
||||
if search_text:
|
||||
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
|
||||
qry = qry.filter(User.display_name.like(pattern))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user