diff --git a/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py b/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py new file mode 100644 index 00000000..19b497fa --- /dev/null +++ b/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py @@ -0,0 +1,26 @@ +"""add ix_user_display_name_id + +Revision ID: 4f26fec4d204 +Revises: 0964ac95170d +Create Date: 2023-01-19 22:00:00 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4f26fec4d204" +down_revision = "7f0a28a156f4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_user_display_name_id", table_name="user") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 0b2db515..f2fc3181 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/", response_model=list[protocol.FrontEndUser]) -def get_users( +@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True) +def get_users_ordered_by_username( api_client_id: Optional[UUID] = None, - max_count: Optional[int] = Query(100, gt=0, le=10000), - gt: Optional[str] = None, - lt: Optional[str] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, + search_text: Optional[str] = None, auth_method: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): ur = UserRepository(db, api_client) - users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method) - return [u.to_protocol_frontend_user() for u in users] - - -@router.get("/by_display_name") -def query_frontend_users_by_display_name( - search_text: str, - exact: bool = False, - api_client_id: UUID = None, - max_count: int = Query(20, gt=0, le=1000), - auth_method: str = None, - api_client: ApiClient = Depends(deps.get_api_client), - db: Session = Depends(deps.get_db), -): - ur = UserRepository(db, api_client) - users = ur.query_users_by_display_name( - search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method + users = ur.query_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, ) return [u.to_protocol_frontend_user() for u in users] diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 36cd65c9..0b31495a 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/users/{user_id}", response_model=protocol.FrontEndUser) +@router.get("/by_username", response_model=list[protocol.FrontEndUser]) +def get_users_ordered_by_username( + api_client_id: Optional[UUID] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, + search_text: Optional[str] = None, + auth_method: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = UserRepository(db, api_client) + users = ur.query_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, + ) + return [u.to_protocol_frontend_user() for u in users] + + +@router.get("/by_display_name", response_model=list[protocol.FrontEndUser]) +def get_users_ordered_by_display_name( + api_client_id: Optional[UUID] = None, + gte_display_name: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_display_name: Optional[str] = None, + lt_id: Optional[UUID] = None, + auth_method: Optional[str] = None, + search_text: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = UserRepository(db, api_client) + users = ur.query_users_ordered_by_display_name( + api_client_id=api_client_id, + gte_display_name=gte_display_name, + gt_id=gt_id, + lte_display_name=lte_display_name, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, + ) + return [u.to_protocol_frontend_user() for u in users] + + +@router.get("/{user_id}", response_model=protocol.FrontEndUser) def get_user( user_id: UUID, api_client_id: UUID = None, @@ -31,7 +85,7 @@ def get_user( return user.to_protocol_frontend_user() -@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT) def update_user( user_id: UUID, enabled: Optional[bool] = None, @@ -46,7 +100,7 @@ def update_user( ur.update_user(user_id, enabled, notes) -@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT) def delete_user( user_id: UUID, db: Session = Depends(deps.get_db), diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 0fb36c22..d882a15a 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel class User(SQLModel, table=True): __tablename__ = "user" - __table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),) + __table_args__ = ( + Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True), + Index("ix_user_display_name_id", "display_name", "id", unique=True), + ) id: Optional[UUID] = Field( sa_column=sa.Column( diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 578dc5f1..c0c2a88d 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session +from sqlmodel import Session, and_, or_ from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -135,13 +135,16 @@ class UserRepository: self.db.add(user) return user - def query_users( + def query_users_ordered_by_username( self, api_client_id: Optional[UUID] = None, - limit: Optional[int] = 20, - gt: Optional[str] = None, - lt: Optional[str] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, auth_method: Optional[str] = None, + search_text: Optional[str] = None, + limit: Optional[int] = 100, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -150,34 +153,52 @@ class UserRepository: if api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - users = self.db.query(User) + qry = self.db.query(User).order_by(User.username, User.id) - if api_client_id: - users = users.filter(User.api_client_id == api_client_id) + if gte_username is not None: + if gt_id: + qry = qry.filter( + or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id)) + ) + else: + qry = qry.filter(User.username >= gte_username) + elif gt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_username is not None: + if lt_id: + qry = qry.filter( + or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id)) + ) + else: + qry = qry.filter(User.username <= lte_username) + elif lt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) if auth_method: - users = users.filter(User.auth_method == auth_method) + qry = qry.filter(User.auth_method == auth_method) + if api_client_id: + qry = qry.filter(User.api_client_id == api_client_id) - users = users.order_by(User.display_name) - - if gt: - users = users.filter(User.display_name > gt) - - if lt: - users = users.filter(User.display_name < lt) + if search_text: + pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) + qry = qry.filter(User.username.like(pattern)) if limit is not None: - users = users.limit(limit) + qry = qry.limit(limit) - return users.all() + return qry.all() - def query_users_by_display_name( + def query_users_ordered_by_display_name( self, - search_text: str, - exact: Optional[bool] = False, - limit: Optional[int] = 20, + gte_display_name: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_display_name: Optional[str] = None, + lt_id: Optional[UUID] = None, api_client_id: Optional[UUID] = None, auth_method: Optional[str] = None, + search_text: Optional[str] = None, + limit: Optional[int] = 100, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -186,11 +207,40 @@ class UserRepository: if api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - qry = self.db.query(User).order_by(User.display_name) + qry = self.db.query(User).order_by(User.display_name, User.id) - if exact: - qry = qry.filter(User.display_name == search_text) - else: + if gte_display_name is not None: + if gt_id: + qry = qry.filter( + or_( + User.display_name > gte_display_name, + and_(User.display_name == gte_display_name, User.id > gt_id), + ) + ) + else: + qry = qry.filter(User.display_name >= gte_display_name) + elif gt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_display_name is not None: + if lt_id: + qry = qry.filter( + or_( + User.display_name < lte_display_name, + and_(User.display_name == lte_display_name, User.id < lt_id), + ) + ) + else: + qry = qry.filter(User.display_name <= lte_display_name) + elif lt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if auth_method: + qry = qry.filter(User.auth_method == auth_method) + if api_client_id: + qry = qry.filter(User.api_client_id == api_client_id) + + if search_text: pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) qry = qry.filter(User.display_name.like(pattern))