fix desc ordering of message cursor endpoints

This commit is contained in:
Andreas Köpf
2023-01-22 18:37:09 +01:00
parent f645b53405
commit 0b8ed52102
3 changed files with 40 additions and 20 deletions
@@ -105,8 +105,8 @@ def query_frontend_user_messages_cursor(
db: Session = Depends(deps.get_db),
):
return get_messages_cursor(
lt=lt,
gt=gt,
before=lt,
after=gt,
auth_method=auth_method,
username=username,
only_roots=only_roots,
+34 -14
View File
@@ -50,8 +50,8 @@ def query_messages(
@router.get("/cursor", response_model=protocol.MessagePage)
def get_messages_cursor(
lt: Optional[str] = None,
gt: Optional[str] = None,
before: Optional[str] = None,
after: Optional[str] = None,
user_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
username: Optional[str] = None,
@@ -63,6 +63,8 @@ def get_messages_cursor(
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
assert max_count is not None
def split_cursor(x: str | None) -> tuple[datetime, UUID]:
if not x:
return None, None
@@ -74,11 +76,21 @@ def get_messages_cursor(
except ValueError:
raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE)
lte_created_date, lt_id = split_cursor(lt)
gte_created_date, gt_id = split_cursor(gt)
if desc:
gte_created_date, gt_id = split_cursor(before)
lte_created_date, lt_id = split_cursor(after)
query_desc = not (before is not None and not after)
else:
lte_created_date, lt_id = split_cursor(before)
gte_created_date, gt_id = split_cursor(after)
query_desc = before is not None and not after
print(f"{desc=} {query_desc=} {gte_created_date=} {lte_created_date=}")
qry_max_count = max_count + 1 if before is None or after is None else max_count
pr = PromptRepository(db, api_client)
messages = pr.query_messages_ordered_by_created_date(
items = pr.query_messages_ordered_by_created_date(
user_id=user_id,
auth_method=auth_method,
username=username,
@@ -89,22 +101,30 @@ def get_messages_cursor(
lt_id=lt_id,
only_roots=only_roots,
deleted=None if include_deleted else False,
desc=desc,
limit=max_count,
desc=query_desc,
limit=qry_max_count,
)
items = utils.prepare_message_list(messages)
num_rows = len(items)
if qry_max_count > max_count and num_rows == qry_max_count:
assert not (before and after)
items = items[:-1]
if desc != query_desc:
items.reverse()
items = utils.prepare_message_list(items)
n, p = None, None
if len(items) > 0:
if len(items) == max_count or gte_created_date:
if (num_rows > max_count and before) or after:
p = str(items[0].id) + "$" + items[0].created_date.isoformat()
if len(items) == max_count or lte_created_date:
if num_rows > max_count or before:
n = str(items[-1].id) + "$" + items[-1].created_date.isoformat()
else:
if gte_created_date:
p = gte_created_date.isoformat()
if lte_created_date:
n = lte_created_date.isoformat()
if after:
p = lte_created_date.isoformat() if desc else gte_created_date.isoformat()
if before:
n = gte_created_date.isoformat() if desc else lte_created_date.isoformat()
order = "desc" if desc else "asc"
return protocol.MessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items)
+4 -4
View File
@@ -100,7 +100,7 @@ def get_users_cursor(
items: list[protocol.FrontEndUser]
qry_max_count = max_count + 1 if lt is None or gt is None else max_count
desc = lt and not gt
desc = lt is not None and not gt
def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]):
p, n = None, None
@@ -119,7 +119,7 @@ def get_users_cursor(
def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str | None):
num_rows = len(items)
if qry_max_count > max_count and num_rows == qry_max_count:
assert not (lt and gt)
assert not (lt is not None and gt is not None)
items = items[:-1]
if desc:
items.reverse()
@@ -257,8 +257,8 @@ def query_user_messages_cursor(
db: Session = Depends(deps.get_db),
):
return get_messages_cursor(
lt=lt,
gt=gt,
before=lt,
after=gt,
user_id=user_id,
only_roots=only_roots,
include_deleted=include_deleted,