diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 5d96aa90..51f8cd20 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -105,8 +105,8 @@ def query_frontend_user_messages_cursor( db: Session = Depends(deps.get_db), ): return get_messages_cursor( - lt=lt, - gt=gt, + before=lt, + after=gt, auth_method=auth_method, username=username, only_roots=only_roots, diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 1ef1e929..d3d5e1c3 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -50,8 +50,8 @@ def query_messages( @router.get("/cursor", response_model=protocol.MessagePage) def get_messages_cursor( - lt: Optional[str] = None, - gt: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, user_id: Optional[UUID] = None, auth_method: Optional[str] = None, username: Optional[str] = None, @@ -63,6 +63,8 @@ def get_messages_cursor( api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): + assert max_count is not None + def split_cursor(x: str | None) -> tuple[datetime, UUID]: if not x: return None, None @@ -74,11 +76,21 @@ def get_messages_cursor( except ValueError: raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE) - lte_created_date, lt_id = split_cursor(lt) - gte_created_date, gt_id = split_cursor(gt) + if desc: + gte_created_date, gt_id = split_cursor(before) + lte_created_date, lt_id = split_cursor(after) + query_desc = not (before is not None and not after) + else: + lte_created_date, lt_id = split_cursor(before) + gte_created_date, gt_id = split_cursor(after) + query_desc = before is not None and not after + + print(f"{desc=} {query_desc=} {gte_created_date=} {lte_created_date=}") + + qry_max_count = max_count + 1 if before is None or after is None else max_count pr = PromptRepository(db, api_client) - messages = pr.query_messages_ordered_by_created_date( + items = pr.query_messages_ordered_by_created_date( user_id=user_id, auth_method=auth_method, username=username, @@ -89,22 +101,30 @@ def get_messages_cursor( lt_id=lt_id, only_roots=only_roots, deleted=None if include_deleted else False, - desc=desc, - limit=max_count, + desc=query_desc, + limit=qry_max_count, ) - items = utils.prepare_message_list(messages) + num_rows = len(items) + if qry_max_count > max_count and num_rows == qry_max_count: + assert not (before and after) + items = items[:-1] + + if desc != query_desc: + items.reverse() + + items = utils.prepare_message_list(items) n, p = None, None if len(items) > 0: - if len(items) == max_count or gte_created_date: + if (num_rows > max_count and before) or after: p = str(items[0].id) + "$" + items[0].created_date.isoformat() - if len(items) == max_count or lte_created_date: + if num_rows > max_count or before: n = str(items[-1].id) + "$" + items[-1].created_date.isoformat() else: - if gte_created_date: - p = gte_created_date.isoformat() - if lte_created_date: - n = lte_created_date.isoformat() + if after: + p = lte_created_date.isoformat() if desc else gte_created_date.isoformat() + if before: + n = gte_created_date.isoformat() if desc else lte_created_date.isoformat() order = "desc" if desc else "asc" return protocol.MessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index c7ff9f9c..63bb79ea 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -100,7 +100,7 @@ def get_users_cursor( items: list[protocol.FrontEndUser] qry_max_count = max_count + 1 if lt is None or gt is None else max_count - desc = lt and not gt + desc = lt is not None and not gt def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]): p, n = None, None @@ -119,7 +119,7 @@ def get_users_cursor( def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str | None): num_rows = len(items) if qry_max_count > max_count and num_rows == qry_max_count: - assert not (lt and gt) + assert not (lt is not None and gt is not None) items = items[:-1] if desc: items.reverse() @@ -257,8 +257,8 @@ def query_user_messages_cursor( db: Session = Depends(deps.get_db), ): return get_messages_cursor( - lt=lt, - gt=gt, + before=lt, + after=gt, user_id=user_id, only_roots=only_roots, include_deleted=include_deleted,