Add keyset pagination for users ordered by username / display_name (#851)

* add keyset pagination for user ordered by username or display_name

* add index on display-name for user table

* update down_revision in migration script
This commit is contained in:
Andreas Köpf
2023-01-20 16:32:13 +01:00
committed by GitHub
parent 7488f175f2
commit 70fc80aa08
5 changed files with 180 additions and 52 deletions
@@ -0,0 +1,26 @@
"""add ix_user_display_name_id
Revision ID: 4f26fec4d204
Revises: 0964ac95170d
Create Date: 2023-01-19 22:00:00
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "4f26fec4d204"
down_revision = "7f0a28a156f4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_user_display_name_id", table_name="user")
# ### end Alembic commands ###
+17 -22
View File
@@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/", response_model=list[protocol.FrontEndUser])
def get_users(
@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True)
def get_users_ordered_by_username(
api_client_id: Optional[UUID] = None,
max_count: Optional[int] = Query(100, gt=0, le=10000),
gt: Optional[str] = None,
lt: Optional[str] = None,
gte_username: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_username: Optional[str] = None,
lt_id: Optional[UUID] = None,
search_text: Optional[str] = None,
auth_method: Optional[str] = None,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method)
return [u.to_protocol_frontend_user() for u in users]
@router.get("/by_display_name")
def query_frontend_users_by_display_name(
search_text: str,
exact: bool = False,
api_client_id: UUID = None,
max_count: int = Query(20, gt=0, le=1000),
auth_method: str = None,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
users = ur.query_users_by_display_name(
search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method
users = ur.query_users_ordered_by_username(
api_client_id=api_client_id,
gte_username=gte_username,
gt_id=gt_id,
lte_username=lte_username,
lt_id=lt_id,
auth_method=auth_method,
search_text=search_text,
limit=max_count,
)
return [u.to_protocol_frontend_user() for u in users]
+57 -3
View File
@@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/users/{user_id}", response_model=protocol.FrontEndUser)
@router.get("/by_username", response_model=list[protocol.FrontEndUser])
def get_users_ordered_by_username(
api_client_id: Optional[UUID] = None,
gte_username: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_username: Optional[str] = None,
lt_id: Optional[UUID] = None,
search_text: Optional[str] = None,
auth_method: Optional[str] = None,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
users = ur.query_users_ordered_by_username(
api_client_id=api_client_id,
gte_username=gte_username,
gt_id=gt_id,
lte_username=lte_username,
lt_id=lt_id,
auth_method=auth_method,
search_text=search_text,
limit=max_count,
)
return [u.to_protocol_frontend_user() for u in users]
@router.get("/by_display_name", response_model=list[protocol.FrontEndUser])
def get_users_ordered_by_display_name(
api_client_id: Optional[UUID] = None,
gte_display_name: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_display_name: Optional[str] = None,
lt_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
search_text: Optional[str] = None,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
users = ur.query_users_ordered_by_display_name(
api_client_id=api_client_id,
gte_display_name=gte_display_name,
gt_id=gt_id,
lte_display_name=lte_display_name,
lt_id=lt_id,
auth_method=auth_method,
search_text=search_text,
limit=max_count,
)
return [u.to_protocol_frontend_user() for u in users]
@router.get("/{user_id}", response_model=protocol.FrontEndUser)
def get_user(
user_id: UUID,
api_client_id: UUID = None,
@@ -31,7 +85,7 @@ def get_user(
return user.to_protocol_frontend_user()
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT)
def update_user(
user_id: UUID,
enabled: Optional[bool] = None,
@@ -46,7 +100,7 @@ def update_user(
ur.update_user(user_id, enabled, notes)
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
def delete_user(
user_id: UUID,
db: Session = Depends(deps.get_db),
+4 -1
View File
@@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel
class User(SQLModel, table=True):
__tablename__ = "user"
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
__table_args__ = (
Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),
Index("ix_user_display_name_id", "display_name", "id", unique=True),
)
id: Optional[UUID] = Field(
sa_column=sa.Column(
+76 -26
View File
@@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from sqlmodel import Session, and_, or_
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -135,13 +135,16 @@ class UserRepository:
self.db.add(user)
return user
def query_users(
def query_users_ordered_by_username(
self,
api_client_id: Optional[UUID] = None,
limit: Optional[int] = 20,
gt: Optional[str] = None,
lt: Optional[str] = None,
gte_username: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_username: Optional[str] = None,
lt_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
search_text: Optional[str] = None,
limit: Optional[int] = 100,
) -> list[User]:
if not self.api_client.trusted:
if not api_client_id:
@@ -150,34 +153,52 @@ class UserRepository:
if api_client_id != self.api_client.id:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
users = self.db.query(User)
qry = self.db.query(User).order_by(User.username, User.id)
if api_client_id:
users = users.filter(User.api_client_id == api_client_id)
if gte_username is not None:
if gt_id:
qry = qry.filter(
or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id))
)
else:
qry = qry.filter(User.username >= gte_username)
elif gt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if lte_username is not None:
if lt_id:
qry = qry.filter(
or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id))
)
else:
qry = qry.filter(User.username <= lte_username)
elif lt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if auth_method:
users = users.filter(User.auth_method == auth_method)
qry = qry.filter(User.auth_method == auth_method)
if api_client_id:
qry = qry.filter(User.api_client_id == api_client_id)
users = users.order_by(User.display_name)
if gt:
users = users.filter(User.display_name > gt)
if lt:
users = users.filter(User.display_name < lt)
if search_text:
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
qry = qry.filter(User.username.like(pattern))
if limit is not None:
users = users.limit(limit)
qry = qry.limit(limit)
return users.all()
return qry.all()
def query_users_by_display_name(
def query_users_ordered_by_display_name(
self,
search_text: str,
exact: Optional[bool] = False,
limit: Optional[int] = 20,
gte_display_name: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_display_name: Optional[str] = None,
lt_id: Optional[UUID] = None,
api_client_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
search_text: Optional[str] = None,
limit: Optional[int] = 100,
) -> list[User]:
if not self.api_client.trusted:
if not api_client_id:
@@ -186,11 +207,40 @@ class UserRepository:
if api_client_id != self.api_client.id:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
qry = self.db.query(User).order_by(User.display_name)
qry = self.db.query(User).order_by(User.display_name, User.id)
if exact:
qry = qry.filter(User.display_name == search_text)
else:
if gte_display_name is not None:
if gt_id:
qry = qry.filter(
or_(
User.display_name > gte_display_name,
and_(User.display_name == gte_display_name, User.id > gt_id),
)
)
else:
qry = qry.filter(User.display_name >= gte_display_name)
elif gt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if lte_display_name is not None:
if lt_id:
qry = qry.filter(
or_(
User.display_name < lte_display_name,
and_(User.display_name == lte_display_name, User.id < lt_id),
)
)
else:
qry = qry.filter(User.display_name <= lte_display_name)
elif lt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if auth_method:
qry = qry.filter(User.auth_method == auth_method)
if api_client_id:
qry = qry.filter(User.api_client_id == api_client_id)
if search_text:
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
qry = qry.filter(User.display_name.like(pattern))