From 0c5e2fc45deb0f8491d34229b728428e4ef1f964 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 20 Jan 2023 13:00:59 +0100 Subject: [PATCH 1/4] Show last updated on leaderboard --- website/public/locales/en/common.json | 3 ++- .../LeaderboardGridCell.tsx | 27 ++++++++++++++----- website/src/pages/api/leaderboard.ts | 6 ++--- website/src/types/Leaderboard.ts | 1 + 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index e18eb8ec..de764cc2 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -13,5 +13,6 @@ "sign_in": "Sign In", "sign_out": "Sign Out", "terms_of_service": "Terms of Service", - "title": "Open Assistant" + "title": "Open Assistant", + "last_updated_at": "Last updated at: {{val, datetime}}" } diff --git a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx index df18735d..9750a851 100644 --- a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx +++ b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx @@ -1,8 +1,9 @@ -import { Table, TableContainer, Tbody, Td, Th, Thead, Tr, useColorModeValue } from "@chakra-ui/react"; -import React from "react"; +import { Table, TableContainer, Tbody, Td, Text, Th, Thead, Tr, useColorModeValue } from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import React, { useMemo } from "react"; import { useTable } from "react-table"; import { get } from "src/lib/api"; -import { LeaderboardEntity, LeaderboardTimeFrame } from "src/types/Leaderboard"; +import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import useSWRImmutable from "swr/immutable"; const columns = [ @@ -26,13 +27,26 @@ const columns = [ * Presents a grid of leaderboard entries with more detailed information. */ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) => { - const { data } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}`, get, { - fallbackData: [], + const { t } = useTranslation(); + const { data: reply } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}`, get, { revalidateOnMount: true, }); + + const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({ + columns, + data: reply?.leaderboard ?? [], + }); + const backgroundColor = useColorModeValue("white", "gray.800"); - const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({ columns, data }); + const lastUpdated = useMemo(() => { + const val = new Date(reply?.last_updated); + return t("last_updated_at", { val, formatParams: { val: { dateStyle: "full", timeStyle: "short" } } }); + }, [t, reply?.last_updated]); + + if (!reply) { + return null; + } return ( @@ -66,6 +80,7 @@ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) })} + {lastUpdated} ); }; diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index 592f3da5..1ddf947e 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -6,9 +6,9 @@ import { LeaderboardTimeFrame } from "src/types/Leaderboard"; * Returns the set of valid labels that can be applied to messages. */ const handler = withoutRole("banned", async (req, res) => { - const time_frame = (req.query.time_frame as LeaderboardTimeFrame) || LeaderboardTimeFrame.day; - const { leaderboard } = await oasstApiClient.fetch_leaderboard(time_frame); - res.status(200).json(leaderboard); + const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day; + const info = await oasstApiClient.fetch_leaderboard(time_frame); + res.status(200).json(info); }); export default handler; diff --git a/website/src/types/Leaderboard.ts b/website/src/types/Leaderboard.ts index 21c91766..5c0acfc3 100644 --- a/website/src/types/Leaderboard.ts +++ b/website/src/types/Leaderboard.ts @@ -12,6 +12,7 @@ export const enum LeaderboardTimeFrame { } export interface LeaderboardReply { time_frame: LeaderboardTimeFrame; + last_updated: string; // date time iso string leaderboard: LeaderboardEntity[]; } From 2802bc15812fbc75d84405e8c6e3d45871b00145 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 20 Jan 2023 13:51:47 +0100 Subject: [PATCH 2/4] Render whitespace in messages --- website/src/components/Messages/MessageTableEntry.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index d18bd910..1205991e 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -50,6 +50,7 @@ export function MessageTableEntry(props: MessageTableEntryProps) { bg={item.is_assistant ? backgroundColor : backgroundColor2} onClick={props.enabled && goToMessage} _hover={props.enabled && { cursor: "pointer", opacity: 0.9 }} + whiteSpace="pre-wrap" > {inlineAvatar && avatar} {item.text} From 2d21b65ed0afb3bbeeac83f3f37c8075cf854a2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 20 Jan 2023 19:58:33 +0100 Subject: [PATCH 3/4] Add lang-tag based task selection (lang-separation) (#863) * lang based task selection * use BCP 47 instead of ISO 639-1 * add Field(None, nullable=True) * update migration script down_revision --- ...cc_use_en_instead_en_us_as_default_lang.py | 29 +++++++++ backend/main.py | 2 + backend/oasst_backend/api/v1/tasks.py | 5 +- backend/oasst_backend/api/v1/utils.py | 6 +- backend/oasst_backend/models/message.py | 2 +- backend/oasst_backend/prompt_repository.py | 4 ++ backend/oasst_backend/tree_manager.py | 64 +++++++++++++------ oasst-shared/oasst_shared/schemas/protocol.py | 3 + oasst-shared/tests/test_oasst_api_client.py | 1 + 9 files changed, 90 insertions(+), 26 deletions(-) create mode 100644 backend/alembic/versions/2023_01_20_1650-160ac010efcc_use_en_instead_en_us_as_default_lang.py diff --git a/backend/alembic/versions/2023_01_20_1650-160ac010efcc_use_en_instead_en_us_as_default_lang.py b/backend/alembic/versions/2023_01_20_1650-160ac010efcc_use_en_instead_en_us_as_default_lang.py new file mode 100644 index 00000000..7e12dddb --- /dev/null +++ b/backend/alembic/versions/2023_01_20_1650-160ac010efcc_use_en_instead_en_us_as_default_lang.py @@ -0,0 +1,29 @@ +"""use 'en' instead 'en-US' as default lang + +Revision ID: 160ac010efcc +Revises: 4f26fec4d204 +Create Date: 2023-01-20 16:50:00 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "160ac010efcc" +down_revision = "4f26fec4d204" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message", "lang") + op.add_column("message", sa.Column("lang", sa.String(length=32), server_default="en", nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message", "lang") + op.add_column("message", sa.Column("lang", sa.VARCHAR(length=200), autoincrement=False, nullable=False)) + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index ce316f41..147f8e30 100644 --- a/backend/main.py +++ b/backend/main.py @@ -128,6 +128,7 @@ if settings.DEBUG_USE_SEED_DATA: user_message_id: str parent_message_id: Optional[str] text: str + lang: Optional[str] role: str tree_state: Optional[message_tree_state.State] @@ -184,6 +185,7 @@ if settings.DEBUG_USE_SEED_DATA: tr.bind_frontend_message_id(task.id, msg.task_message_id) message = pr.store_text_reply( msg.text, + msg.lang, msg.task_message_id, msg.user_message_id, review_count=5, diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 9e9118da..7875ac83 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -39,7 +39,7 @@ def request_task( pr.ensure_user_is_enabled() tm = TreeManager(db, pr) - task, message_tree_id, parent_message_id = tm.next_task(request.type) + task, message_tree_id, parent_message_id = tm.next_task(desired_task_type=request.type, lang=request.lang) pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective) except OasstError: @@ -54,6 +54,7 @@ def request_task( def tasks_availability( *, user: Optional[protocol_schema.User] = None, + lang: Optional[str] = "en", db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), ): @@ -62,7 +63,7 @@ def tasks_availability( try: pr = PromptRepository(db, api_client, client_user=user) tm = TreeManager(db, pr) - return tm.determine_task_availability() + return tm.determine_task_availability(lang) except OasstError: raise diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 035f6d19..49cbbd29 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -10,6 +10,7 @@ def prepare_message(m: Message) -> protocol.Message: frontend_message_id=m.frontend_message_id, parent_id=m.parent_id, text=m.text, + lang=m.lang, is_assistant=(m.role == "assistant"), created_date=m.created_date, ) @@ -22,10 +23,11 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]: def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]: return [ protocol.ConversationMessage( - text=message.text, - is_assistant=(message.role == "assistant"), id=message.id, frontend_message_id=message.frontend_message_id, + text=message.text, + lang=message.lang, + is_assistant=(message.role == "assistant"), ) for message in messages ] diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index b03c8534..d0b1d869 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -38,7 +38,7 @@ class Message(SQLModel, table=True): payload: Optional[PayloadContainer] = Field( sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True) ) - lang: str = Field(nullable=False, max_length=200, default="en-US") + lang: str = Field(sa_column=sa.Column(sa.String(32), server_default="en", nullable=False)) depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 35e1ece4..5e928d8b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -85,6 +85,7 @@ class PromptRepository: task_id: UUID, role: str, payload: db_payload.MessagePayload, + lang: str, payload_type: str = None, depth: int = 0, review_count: int = 0, @@ -107,6 +108,7 @@ class PromptRepository: api_client_id=self.api_client.id, payload_type=payload_type, payload=PayloadContainer(payload=payload), + lang=lang, depth=depth, review_count=review_count, review_result=review_result, @@ -146,6 +148,7 @@ class PromptRepository: def store_text_reply( self, text: str, + lang: str, frontend_message_id: str, user_frontend_message_id: str, review_count: int = 0, @@ -209,6 +212,7 @@ class PromptRepository: task_id=task.id, role=role, payload=db_payload.MessagePayload(text=text), + lang=lang or "en", depth=depth, review_count=review_count, review_result=review_result, diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index f4338048..026bb564 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -190,14 +190,18 @@ class TreeManager: return task_count_by_type - def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]: + def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]: self.pr.ensure_user_is_enabled() - num_active_trees = self.query_num_active_trees() - extendible_parents = self.query_extendible_parents() - prompts_need_review = self.query_prompts_need_review() - replies_need_review = self.query_replies_need_review() - incomplete_rankings = self.query_incomplete_rankings() + if not lang: + lang = "en" + logger.warning("Task availability request without lang tag received, assuming lang='en'.") + + num_active_trees = self.query_num_active_trees(lang=lang) + extendible_parents = self.query_extendible_parents(lang=lang) + prompts_need_review = self.query_prompts_need_review(lang=lang) + replies_need_review = self.query_replies_need_review(lang=lang) + incomplete_rankings = self.query_incomplete_rankings(lang=lang) return self._determine_task_availability_internal( num_active_trees=num_active_trees, @@ -208,23 +212,29 @@ class TreeManager: ) def next_task( - self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random + self, + desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random, + lang: str = "en", ) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]: - logger.debug("TreeManager.next_task()") + logger.debug(f"TreeManager.next_task({desired_task_type=}, {lang=})") self.pr.ensure_user_is_enabled() - num_active_trees = self.query_num_active_trees() - prompts_need_review = self.query_prompts_need_review() - replies_need_review = self.query_replies_need_review() - extendible_parents = self.query_extendible_parents() + if not lang: + lang = "en" + logger.warning("Task request without lang tag received, assuming 'en'.") - incomplete_rankings = self.query_incomplete_rankings() + num_active_trees = self.query_num_active_trees(lang=lang) + prompts_need_review = self.query_prompts_need_review(lang=lang) + replies_need_review = self.query_replies_need_review(lang=lang) + extendible_parents = self.query_extendible_parents(lang=lang) + + incomplete_rankings = self.query_incomplete_rankings(lang=lang) if not self.cfg.rank_prompter_replies: incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings)) - active_tree_sizes = self.query_extendible_trees() + active_tree_sizes = self.query_extendible_trees(lang=lang) # determine type of task to generate num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes) @@ -458,6 +468,7 @@ class TreeManager: # here we store the text reply in the database message = pr.store_text_reply( text=interaction.text, + lang=interaction.lang, frontend_message_id=interaction.message_id, user_frontend_message_id=interaction.user_message_id, ) @@ -665,7 +676,7 @@ class TreeManager: # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) - def query_prompts_need_review(self) -> list[Message]: + def query_prompts_need_review(self, lang: str) -> list[Message]: """ Select initial prompt messages with less then required rankings in active message tree (active == True in message_tree_state) @@ -682,6 +693,7 @@ class TreeManager: not_(Message.deleted), Message.review_count < self.cfg.num_reviews_initial_prompt, Message.parent_id.is_(None), + Message.lang == lang, ) ) @@ -690,7 +702,7 @@ class TreeManager: return qry.all() - def query_replies_need_review(self) -> list[Message]: + def query_replies_need_review(self, lang: str) -> list[Message]: """ Select child messages (parent_id IS NOT NULL) with less then required rankings in active message tree (active == True in message_tree_state) @@ -707,6 +719,7 @@ class TreeManager: not_(Message.deleted), Message.review_count < self.cfg.num_reviews_reply, Message.parent_id.is_not(None), + Message.lang == lang, ) ) @@ -724,13 +737,14 @@ FROM message_tree_state mts WHERE mts.active -- only consider active trees AND mts.state = :ranking_state -- message tree must be in ranking state AND m.review_result -- must be reviewed + AND m.lang = :lang -- matches lang AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts GROUP BY m.parent_id, m.role HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings """ - def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]: + def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]: """Query parents which have childern that need further rankings""" r = self.db.execute( @@ -738,6 +752,7 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings { "num_required_rankings": self.cfg.num_required_rankings, "ranking_state": message_tree_state.State.RANKING, + "lang": lang, }, ) return [IncompleteRankingsRow.from_orm(x) for x in r.all()] @@ -753,13 +768,14 @@ WHERE mts.active -- only consider active trees AND NOT m.deleted -- ignore deleted messages as parents AND m.depth < mts.max_depth -- ignore leaf nodes as parents AND m.review_result -- parent node must have positive review + AND m.lang = :lang -- parent matches lang AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children """ - def query_extendible_parents(self) -> list[ExtendibleParentRow]: + def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]: """Query parent messages that have not reached the maximum number of replies.""" r = self.db.execute( @@ -767,6 +783,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children { "growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply, + "lang": lang, }, ) return [ExtendibleParentRow.from_orm(x) for x in r.all()] @@ -787,7 +804,7 @@ GROUP BY m.message_tree_id, mts.goal_tree_size HAVING COUNT(m.id) < mts.goal_tree_size """ - def query_extendible_trees(self) -> list[ActiveTreeSizeRow]: + def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]: """Query size of active message trees in growing state.""" r = self.db.execute( @@ -795,6 +812,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size { "growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply, + "lang": lang, }, ) return [ActiveTreeSizeRow.from_orm(x) for x in r.all()] @@ -894,8 +912,12 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})") self._insert_default_state(id, state=state) - def query_num_active_trees(self) -> int: - query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active) + def query_num_active_trees(self, lang: str) -> int: + query = ( + self.db.query(func.count(MessageTreeState.message_tree_id)) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter(MessageTreeState.active, Message.lang == lang) + ) return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 388bd0d6..d31f1c79 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -43,6 +43,7 @@ class ConversationMessage(BaseModel): id: Optional[UUID] = None frontend_message_id: Optional[str] = None text: str + lang: Optional[str] # BCP 47 is_assistant: bool @@ -72,6 +73,7 @@ class TaskRequest(BaseModel): # this is optional. https://github.com/pydantic/pydantic/issues/1270 user: Optional[User] = Field(None, nullable=True) collective: bool = False + lang: Optional[str] = Field(None, nullable=True) # BCP 47 class TaskAck(BaseModel): @@ -266,6 +268,7 @@ class TextReplyToMessage(Interaction): message_id: str user_message_id: str text: constr(min_length=1, strip_whitespace=True) + lang: Optional[str] # BCP 47 class MessageRating(Interaction): diff --git a/oasst-shared/tests/test_oasst_api_client.py b/oasst-shared/tests/test_oasst_api_client.py index fdb743ce..e1515123 100644 --- a/oasst-shared/tests/test_oasst_api_client.py +++ b/oasst-shared/tests/test_oasst_api_client.py @@ -73,6 +73,7 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient): message_id="123", user_message_id="321", text="This is my reply", + lang="en", user=protocol_schema.User( id="123", display_name="lomz", From 94e5d50537dad94bac6f22502ae25cf54edbc0b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 21 Jan 2023 00:29:53 +0100 Subject: [PATCH 4/4] add users/cursor keyset_pagination endpoint (#866) * add users/cursor endpoint * add messages/cursor endpoint * add user/{user_id}/messages/cursor, frontend_user/{auth_method}/{username}/messages/cursor * user regex to parse cursor value --- .../oasst_backend/api/v1/frontend_users.py | 41 +++++++- backend/oasst_backend/api/v1/messages.py | 87 +++++++++++++--- backend/oasst_backend/api/v1/users.py | 99 ++++++++++++++++++- backend/oasst_backend/api/v1/utils.py | 6 ++ backend/oasst_backend/prompt_repository.py | 84 ++++++++++------ backend/oasst_backend/user_repository.py | 2 + .../exceptions/oasst_api_error.py | 4 + oasst-shared/oasst_shared/schemas/protocol.py | 16 +++ 8 files changed, 287 insertions(+), 52 deletions(-) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index f2fc3181..5d96aa90 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -5,6 +5,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.api.v1 import utils +from oasst_backend.api.v1.messages import get_messages_cursor from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository from oasst_backend.user_repository import UserRepository @@ -76,20 +77,47 @@ def query_frontend_user_messages( Query frontend user messages. """ pr = PromptRepository(db, api_client) - messages = pr.query_messages( + messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, api_client_id=api_client_id, desc=desc, limit=max_count, - start_date=start_date, - end_date=end_date, + gte_created_date=start_date, + lte_created_date=end_date, only_roots=only_roots, deleted=None if include_deleted else False, ) return utils.prepare_message_list(messages) +@router.get("/{auth_method}/{username}/messages/cursor", response_model=protocol.MessagePage) +def query_frontend_user_messages_cursor( + auth_method: str, + username: str, + lt: Optional[str] = None, + gt: Optional[str] = None, + only_roots: Optional[bool] = False, + include_deleted: Optional[bool] = False, + max_count: Optional[int] = Query(10, gt=0, le=1000), + desc: Optional[bool] = False, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + return get_messages_cursor( + lt=lt, + gt=gt, + auth_method=auth_method, + username=username, + only_roots=only_roots, + include_deleted=include_deleted, + max_count=max_count, + desc=desc, + api_client=api_client, + db=db, + ) + + @router.delete("/{auth_method}/{username}/messages", status_code=HTTP_204_NO_CONTENT) def mark_frontend_user_messages_deleted( auth_method: str, @@ -98,5 +126,10 @@ def mark_frontend_user_messages_deleted( db: Session = Depends(deps.get_db), ): pr = PromptRepository(db, api_client) - messages = pr.query_messages(auth_method=auth_method, username=username, api_client_id=api_client.id) + messages = pr.query_messages_ordered_by_created_date( + auth_method=auth_method, + username=username, + api_client_id=api_client.id, + limit=None, + ) pr.mark_messages_deleted(messages) diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 409240cb..2dcca64b 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -1,4 +1,5 @@ -import datetime +from datetime import datetime +from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Query @@ -6,8 +7,8 @@ from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol -from oasst_shared.utils import unaware_to_utc from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -16,31 +17,30 @@ router = APIRouter() @router.get("/", response_model=list[protocol.Message]) def query_messages( - username: str = None, - api_client_id: str = None, - max_count: int = Query(10, gt=0, le=1000), - start_date: datetime.datetime = None, - end_date: datetime.datetime = None, - only_roots: bool = False, - desc: bool = True, - allow_deleted: bool = False, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client_id: Optional[str] = None, + max_count: Optional[int] = Query(10, gt=0, le=1000), + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + only_roots: Optional[bool] = False, + desc: Optional[bool] = True, + allow_deleted: Optional[bool] = False, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): """ Query messages. """ - start_date = unaware_to_utc(start_date) - end_date = unaware_to_utc(end_date) - pr = PromptRepository(db, api_client) - messages = pr.query_messages( + messages = pr.query_messages_ordered_by_created_date( + auth_method=auth_method, username=username, api_client_id=api_client_id, desc=desc, limit=max_count, - start_date=start_date, - end_date=end_date, + gte_created_date=start_date, + lte_created_date=end_date, only_roots=only_roots, deleted=None if allow_deleted else False, ) @@ -48,6 +48,61 @@ def query_messages( return utils.prepare_message_list(messages) +@router.get("/cursor", response_model=protocol.MessagePage) +def get_messages_cursor( + lt: Optional[str] = None, + gt: Optional[str] = None, + user_id: Optional[UUID] = None, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client_id: Optional[str] = None, + only_roots: Optional[bool] = False, + include_deleted: Optional[bool] = False, + max_count: Optional[int] = Query(10, gt=0, le=1000), + desc: Optional[bool] = False, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + def split_cursor(x: str | None) -> tuple[datetime, UUID]: + if not x: + return None, None + try: + m = utils.split_uuid_pattern.match(x) + if m: + return datetime.fromisoformat(m[2]), UUID(m[1]) + return datetime.fromisoformat(x), None + except ValueError: + raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE) + + lte_created_date, lt_id = split_cursor(lt) + gte_created_date, gt_id = split_cursor(gt) + + pr = PromptRepository(db, api_client) + messages = pr.query_messages_ordered_by_created_date( + user_id=user_id, + auth_method=auth_method, + username=username, + api_client_id=api_client_id, + gte_created_date=gte_created_date, + gt_id=gt_id, + lte_created_date=lte_created_date, + lt_id=lt_id, + only_roots=only_roots, + deleted=None if include_deleted else False, + desc=desc, + limit=max_count, + ) + + items = utils.prepare_message_list(messages) + n, p = None, None + if len(items) > 0: + p = str(items[0].id) + "$" + items[0].created_date.isoformat() + n = str(items[-1].id) + "$" + items[-1].created_date.isoformat() + + order = "desc" if desc else "asc" + return protocol.MessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items) + + @router.get("/{message_id}", response_model=protocol.Message) def get_message( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 0b31495a..63a55691 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -5,10 +5,12 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.api.v1 import utils +from oasst_backend.api.v1.messages import get_messages_cursor from oasst_backend.models import ApiClient, User from oasst_backend.prompt_repository import PromptRepository from oasst_backend.user_repository import UserRepository from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -70,6 +72,70 @@ def get_users_ordered_by_display_name( return [u.to_protocol_frontend_user() for u in users] +@router.get("/cursor", response_model=protocol.FrontEndUserPage) +def get_users_cursor( + lt: Optional[str] = None, + gt: Optional[str] = None, + sort_key: Optional[str] = Query("username", max_length=32), + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client_id: Optional[UUID] = None, + search_text: Optional[str] = None, + auth_method: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + def split_cursor(x: str | None) -> tuple[str, UUID]: + if not x: + return None, None + m = utils.split_uuid_pattern.match(x) + if m: + return m[2], UUID(m[1]) + return x, None + + items: list[protocol.FrontEndUser] + n, p = None, None + if sort_key == "username": + lte_username, lt_id = split_cursor(lt) + gte_username, gt_id = split_cursor(gt) + items = get_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + max_count=max_count, + api_client=api_client, + db=db, + ) + if len(items) > 0: + p = str(items[0].user_id) + "$" + items[0].id + n = str(items[-1].user_id) + "$" + items[-1].id + elif sort_key == "display_name": + lte_display_name, lt_id = split_cursor(lt) + gte_display_name, gt_id = split_cursor(gt) + items = get_users_ordered_by_display_name( + api_client_id=api_client_id, + gte_display_name=gte_display_name, + gt_id=gt_id, + lte_display_name=lte_display_name, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + max_count=max_count, + api_client=api_client, + db=db, + ) + if len(items) > 0: + p = str(items[0].user_id) + "$" + items[0].display_name + n = str(items[-1].user_id) + "$" + items[-1].display_name + else: + raise OasstError(f"Unsupported sort key: '{sort_key}'", OasstErrorCode.SORT_KEY_UNSUPPORTED) + + return protocol.FrontEndUserPage(prev=p, next=n, sort_key=sort_key, order="asc", items=items) + + @router.get("/{user_id}", response_model=protocol.FrontEndUser) def get_user( user_id: UUID, @@ -130,13 +196,13 @@ def query_user_messages( Query user messages. """ pr = PromptRepository(db, api_client) - messages = pr.query_messages( + messages = pr.query_messages_ordered_by_created_date( user_id=user_id, api_client_id=api_client_id, desc=desc, limit=max_count, - start_date=start_date, - end_date=end_date, + gte_created_date=start_date, + lte_created_date=end_date, only_roots=only_roots, deleted=None if include_deleted else False, ) @@ -144,12 +210,37 @@ def query_user_messages( return utils.prepare_message_list(messages) +@router.get("/{user_id}/messages/cursor", response_model=protocol.MessagePage) +def query_user_messages_cursor( + user_id: Optional[UUID], + lt: Optional[str] = None, + gt: Optional[str] = None, + only_roots: Optional[bool] = False, + include_deleted: Optional[bool] = False, + max_count: Optional[int] = Query(10, gt=0, le=1000), + desc: Optional[bool] = False, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + return get_messages_cursor( + lt=lt, + gt=gt, + user_id=user_id, + only_roots=only_roots, + include_deleted=include_deleted, + max_count=max_count, + desc=desc, + api_client=api_client, + db=db, + ) + + @router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT) def mark_user_messages_deleted( user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): pr = PromptRepository(db, api_client) - messages = pr.query_messages(user_id=user_id) + messages = pr.query_messages_ordered_by_created_date(user_id=user_id, limit=None) pr.mark_messages_deleted(messages) diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 49cbbd29..99161e32 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -1,3 +1,4 @@ +import re from uuid import UUID from oasst_backend.models import Message @@ -43,3 +44,8 @@ def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree: tree_messages.append(prepare_message(message)) return protocol.MessageTree(id=tree_id, messages=tree_messages) + + +split_uuid_pattern = re.compile( + r"^([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})\$(.*)$" +) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 5e928d8b..abf5b721 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -1,6 +1,6 @@ -import datetime import random from collections import defaultdict +from datetime import datetime from http import HTTPStatus from typing import List, Optional, Tuple from uuid import UUID, uuid4 @@ -28,7 +28,8 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats -from sqlmodel import Session, func, not_, text, update +from oasst_shared.utils import unaware_to_utc +from sqlmodel import Session, and_, func, not_, or_, text, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -664,58 +665,85 @@ class PromptRepository: max_message = max(tree, key=lambda m: m.children_count) return max_message, [m for m in tree if m.parent_id == max_message.id] - def query_messages( + def query_messages_ordered_by_created_date( self, user_id: Optional[UUID] = None, auth_method: Optional[str] = None, username: Optional[str] = None, api_client_id: Optional[UUID] = None, - desc: bool = True, - limit: Optional[int] = 10, - start_date: Optional[datetime.datetime] = None, - end_date: Optional[datetime.datetime] = None, + gte_created_date: Optional[datetime] = None, + gt_id: Optional[UUID] = None, + lte_created_date: Optional[datetime] = None, + lt_id: Optional[UUID] = None, only_roots: bool = False, deleted: Optional[bool] = None, + desc: bool = False, + limit: Optional[int] = 100, ) -> list[Message]: - if not self.api_client.trusted and not api_client_id: - # Let unprivileged api clients query their own messages without api_client_id being set - api_client_id = self.api_client.id + if not self.api_client.trusted: + if not api_client_id: + # Let unprivileged api clients query their own messages without api_client_id being set + api_client_id = self.api_client.id - if not self.api_client.trusted and api_client_id != self.api_client.id: - # Unprivileged api client asks for foreign messages - raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + if api_client_id != self.api_client.id: + # Unprivileged api client asks for foreign messages + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - messages = self.db.query(Message) + qry = self.db.query(Message) if user_id: - messages = messages.filter(Message.user_id == user_id) + qry = qry.filter(Message.user_id == user_id) if username or auth_method: if not username and auth_method: raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED) - messages = messages.join(User) - messages = messages.filter(User.username == username, User.auth_method == auth_method) + qry = qry.join(User) + qry = qry.filter(User.username == username, User.auth_method == auth_method) if api_client_id: - messages = messages.filter(Message.api_client_id == api_client_id) + qry = qry.filter(Message.api_client_id == api_client_id) - if start_date: - messages = messages.filter(Message.created_date >= start_date) - if end_date: - messages = messages.filter(Message.created_date < end_date) + gte_created_date = unaware_to_utc(gte_created_date) + lte_created_date = unaware_to_utc(lte_created_date) + + if gte_created_date is not None: + if gt_id: + qry = qry.filter( + or_( + Message.created_date > gte_created_date, + and_(Message.created_date == gte_created_date, Message.id > gt_id), + ) + ) + else: + qry = qry.filter(Message.created_date >= gte_created_date) + elif gt_id: + raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_created_date is not None: + if lt_id: + qry = qry.filter( + or_( + Message.created_date < lte_created_date, + and_(Message.created_date == lte_created_date, Message.id < lt_id), + ) + ) + else: + qry = qry.filter(Message.created_date <= lte_created_date) + elif lt_id: + raise OasstError("Need id and date for keyset pagination", OasstErrorCode.GENERIC_ERROR) if only_roots: - messages = messages.filter(Message.parent_id.is_(None)) + qry = qry.filter(Message.parent_id.is_(None)) if deleted is not None: - messages = messages.filter(Message.deleted == deleted) + qry = qry.filter(Message.deleted == deleted) if desc: - messages = messages.order_by(Message.created_date.desc()) + qry = qry.order_by(Message.created_date.desc(), Message.id.desc()) else: - messages = messages.order_by(Message.created_date.asc()) + qry = qry.order_by(Message.created_date.asc(), Message.id.asc()) if limit is not None: - messages = messages.limit(limit) + qry = qry.limit(limit) - return messages.all() + return qry.all() def update_children_counts(self, message_tree_id: UUID): sql_update_children_count = """ diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index c0c2a88d..1e9ac78f 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -202,9 +202,11 @@ class UserRepository: ) -> list[User]: if not self.api_client.trusted: if not api_client_id: + # Let unprivileged api clients query their own users without api_client_id being set api_client_id = self.api_client.id if api_client_id != self.api_client.id: + # Unprivileged api client asks for foreign users raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) qry = self.db.query(User).order_by(User.display_name, User.id) diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index e8cd2359..bed6f942 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -19,6 +19,10 @@ class OasstErrorCode(IntEnum): API_CLIENT_NOT_AUTHORIZED = 2 ROOT_TOKEN_NOT_AUTHORIZED = 3 DATABASE_MAX_RETRIES_EXHAUSTED = 4 + + SORT_KEY_UNSUPPORTED = 100 + INVALID_CURSOR_VALUE = 101 + TOO_MANY_REQUESTS = 429 SERVER_ERROR0 = 500 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index d31f1c79..f2164e8f 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -37,6 +37,18 @@ class FrontEndUser(User): created_date: Optional[datetime] = None +class PageResult(BaseModel): + prev: str | None + next: str | None + sort_key: str + items: list + order: Literal["asc", "desc"] + + +class FrontEndUserPage(PageResult): + items: list[FrontEndUser] + + class ConversationMessage(BaseModel): """Represents a message in a conversation between the user and the assistant.""" @@ -58,6 +70,10 @@ class Message(ConversationMessage): created_date: Optional[datetime] = None +class MessagePage(PageResult): + items: list[Message] + + class MessageTree(BaseModel): """All messages belonging to the same message tree."""