mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
fix desc ordering of message cursor endpoints
This commit is contained in:
@@ -105,8 +105,8 @@ def query_frontend_user_messages_cursor(
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
return get_messages_cursor(
|
||||
lt=lt,
|
||||
gt=gt,
|
||||
before=lt,
|
||||
after=gt,
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
only_roots=only_roots,
|
||||
|
||||
@@ -50,8 +50,8 @@ def query_messages(
|
||||
|
||||
@router.get("/cursor", response_model=protocol.MessagePage)
|
||||
def get_messages_cursor(
|
||||
lt: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
@@ -63,6 +63,8 @@ def get_messages_cursor(
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
assert max_count is not None
|
||||
|
||||
def split_cursor(x: str | None) -> tuple[datetime, UUID]:
|
||||
if not x:
|
||||
return None, None
|
||||
@@ -74,11 +76,21 @@ def get_messages_cursor(
|
||||
except ValueError:
|
||||
raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE)
|
||||
|
||||
lte_created_date, lt_id = split_cursor(lt)
|
||||
gte_created_date, gt_id = split_cursor(gt)
|
||||
if desc:
|
||||
gte_created_date, gt_id = split_cursor(before)
|
||||
lte_created_date, lt_id = split_cursor(after)
|
||||
query_desc = not (before is not None and not after)
|
||||
else:
|
||||
lte_created_date, lt_id = split_cursor(before)
|
||||
gte_created_date, gt_id = split_cursor(after)
|
||||
query_desc = before is not None and not after
|
||||
|
||||
print(f"{desc=} {query_desc=} {gte_created_date=} {lte_created_date=}")
|
||||
|
||||
qry_max_count = max_count + 1 if before is None or after is None else max_count
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages_ordered_by_created_date(
|
||||
items = pr.query_messages_ordered_by_created_date(
|
||||
user_id=user_id,
|
||||
auth_method=auth_method,
|
||||
username=username,
|
||||
@@ -89,22 +101,30 @@ def get_messages_cursor(
|
||||
lt_id=lt_id,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
desc=query_desc,
|
||||
limit=qry_max_count,
|
||||
)
|
||||
|
||||
items = utils.prepare_message_list(messages)
|
||||
num_rows = len(items)
|
||||
if qry_max_count > max_count and num_rows == qry_max_count:
|
||||
assert not (before and after)
|
||||
items = items[:-1]
|
||||
|
||||
if desc != query_desc:
|
||||
items.reverse()
|
||||
|
||||
items = utils.prepare_message_list(items)
|
||||
n, p = None, None
|
||||
if len(items) > 0:
|
||||
if len(items) == max_count or gte_created_date:
|
||||
if (num_rows > max_count and before) or after:
|
||||
p = str(items[0].id) + "$" + items[0].created_date.isoformat()
|
||||
if len(items) == max_count or lte_created_date:
|
||||
if num_rows > max_count or before:
|
||||
n = str(items[-1].id) + "$" + items[-1].created_date.isoformat()
|
||||
else:
|
||||
if gte_created_date:
|
||||
p = gte_created_date.isoformat()
|
||||
if lte_created_date:
|
||||
n = lte_created_date.isoformat()
|
||||
if after:
|
||||
p = lte_created_date.isoformat() if desc else gte_created_date.isoformat()
|
||||
if before:
|
||||
n = gte_created_date.isoformat() if desc else lte_created_date.isoformat()
|
||||
|
||||
order = "desc" if desc else "asc"
|
||||
return protocol.MessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items)
|
||||
|
||||
@@ -100,7 +100,7 @@ def get_users_cursor(
|
||||
|
||||
items: list[protocol.FrontEndUser]
|
||||
qry_max_count = max_count + 1 if lt is None or gt is None else max_count
|
||||
desc = lt and not gt
|
||||
desc = lt is not None and not gt
|
||||
|
||||
def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]):
|
||||
p, n = None, None
|
||||
@@ -119,7 +119,7 @@ def get_users_cursor(
|
||||
def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str | None):
|
||||
num_rows = len(items)
|
||||
if qry_max_count > max_count and num_rows == qry_max_count:
|
||||
assert not (lt and gt)
|
||||
assert not (lt is not None and gt is not None)
|
||||
items = items[:-1]
|
||||
if desc:
|
||||
items.reverse()
|
||||
@@ -257,8 +257,8 @@ def query_user_messages_cursor(
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
return get_messages_cursor(
|
||||
lt=lt,
|
||||
gt=gt,
|
||||
before=lt,
|
||||
after=gt,
|
||||
user_id=user_id,
|
||||
only_roots=only_roots,
|
||||
include_deleted=include_deleted,
|
||||
|
||||
Reference in New Issue
Block a user