diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index ed07eb21..c7ff9f9c 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -28,6 +28,7 @@ def get_users_ordered_by_username( search_text: Optional[str] = None, auth_method: Optional[str] = None, max_count: Optional[int] = Query(100, gt=0, le=10000), + desc: Optional[bool] = False, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): @@ -41,6 +42,7 @@ def get_users_ordered_by_username( auth_method=auth_method, search_text=search_text, limit=max_count, + desc=desc, ) return [u.to_protocol_frontend_user() for u in users] @@ -55,6 +57,7 @@ def get_users_ordered_by_display_name( auth_method: Optional[str] = None, search_text: Optional[str] = None, max_count: Optional[int] = Query(100, gt=0, le=10000), + desc: Optional[bool] = False, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): @@ -68,6 +71,7 @@ def get_users_ordered_by_display_name( auth_method=auth_method, search_text=search_text, limit=max_count, + desc=desc, ) return [u.to_protocol_frontend_user() for u in users] @@ -96,6 +100,7 @@ def get_users_cursor( items: list[protocol.FrontEndUser] qry_max_count = max_count + 1 if lt is None or gt is None else max_count + desc = lt and not gt def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]): p, n = None, None @@ -115,10 +120,9 @@ def get_users_cursor( num_rows = len(items) if qry_max_count > max_count and num_rows == qry_max_count: assert not (lt and gt) - if lt: - items = items[1:] - else: - items = items[:-1] + items = items[:-1] + if desc: + items.reverse() return items, num_rows n, p = None, None @@ -134,6 +138,7 @@ def get_users_cursor( auth_method=auth_method, search_text=search_text, max_count=qry_max_count, + desc=desc, api_client=api_client, db=db, ) @@ -152,6 +157,7 @@ def get_users_cursor( auth_method=auth_method, search_text=search_text, max_count=qry_max_count, + desc=desc, api_client=api_client, db=db, ) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 79df99ab..136d3d29 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -145,6 +145,7 @@ class UserRepository: auth_method: Optional[str] = None, search_text: Optional[str] = None, limit: Optional[int] = 100, + desc: bool = False, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -184,14 +185,13 @@ class UserRepository: pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) qry = qry.filter(User.username.like(pattern)) - if limit is not None and lte_username and not gte_username: - # select top rows but return results in ascernding order - sub_qry = qry.order_by(User.username.desc(), User.id.desc()).limit(limit).subquery("u") - qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.username, User.id) + if desc: + qry = qry.order_by(User.username.desc(), User.id.desc()) else: qry = qry.order_by(User.username, User.id) - if limit is not None: - qry = qry.limit(limit) + + if limit is not None: + qry = qry.limit(limit) return qry.all() @@ -205,6 +205,7 @@ class UserRepository: auth_method: Optional[str] = None, search_text: Optional[str] = None, limit: Optional[int] = 100, + desc: bool = False, ) -> list[User]: if not self.api_client.trusted: @@ -256,13 +257,12 @@ class UserRepository: if auth_method: qry = qry.filter(User.auth_method == auth_method) - if limit is not None and lte_display_name and not gte_display_name: - # select top rows but return results in ascernding order - sub_qry = qry.order_by(User.display_name.desc(), User.id.desc()).limit(limit).subquery("u") - qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.display_name, User.id) + if desc: + qry = qry.order_by(User.display_name.desc(), User.id.desc()) else: qry = qry.order_by(User.display_name, User.id) - if limit is not None: - qry = qry.limit(limit) + + if limit is not None: + qry = qry.limit(limit) return qry.all()