diff --git a/text-frontend/auto_main.py b/text-frontend/auto_main.py index e951a28e..31c74d9d 100644 --- a/text-frontend/auto_main.py +++ b/text-frontend/auto_main.py @@ -6,12 +6,10 @@ from uuid import uuid4 import requests import typer +from faker import Faker app = typer.Typer() - - -# debug constants -USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"} +fake = Faker() def _random_message_id(): @@ -26,19 +24,11 @@ def _render_message(message: dict) -> str: @app.command() -def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): +def main( + backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234", random_users: int = 1, task_per_user: int = 10 +): """automates tasks""" - # make sure dummy user has accepted the terms of service - create_user_request = dict(USER) - create_user_request["tos_acceptance"] = True - response = requests.post( - f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key} - ) - response.raise_for_status() - user = response.json() - typer.echo(f"user: {user}") - def _post(path: str, json: dict) -> dict: response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) response.raise_for_status() @@ -60,204 +50,219 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): print(shuffled) return ranks - tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] - q = 0 - while tasks: - task = tasks.pop(0) - print(task) + for i in range(int(random_users)): + name = fake.name() + USER = {"id": name, "display_name": name, "auth_method": "local"} - match (task["type"]): - case "initial_prompt": - typer.echo("Please provide an initial prompt to the assistant.") - if task["hint"]: - typer.echo(f"Hint: {task['hint']}") - # acknowledge task - message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + create_user_request = dict(USER) + # make sure dummy user has accepted the terms of service + create_user_request["tos_acceptance"] = True + response = requests.post( + f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key} + ) + response.raise_for_status() + user = response.json() + typer.echo(f"user: {user}") + q = 0 - prompt = gen_random_text() - user_message_id = _random_message_id() - # send interaction - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "text_reply_to_message", - "message_id": message_id, - "task_id": task["id"], - "user_message_id": user_message_id, - "text": prompt, - "user": USER, - }, - ) - tasks.append(new_task) + tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] - case "label_initial_prompt": - typer.echo("Label the following prompt:") - typer.echo(task["prompt"]) - # acknowledge task - message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + while tasks: + task = tasks.pop(0) + print(task) - valid_labels = task["valid_labels"] - mandatory_labels = task["mandatory_labels"] + match (task["type"]): + case "initial_prompt": + typer.echo("Please provide an initial prompt to the assistant.") + if task["hint"]: + typer.echo(f"Hint: {task['hint']}") + # acknowledge task + message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - labels_dict = None - if task["mode"] == "simple" and len(valid_labels) == 1: - answer = random.choice([True, False]) - labels_dict = {valid_labels[0]: 1 if answer else 0} - else: - labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) - for l in mandatory_labels: - if l not in labels: - labels.append(l) - labels_dict = {label: random.random() for label in valid_labels} - if random.random() < 0.9: - labels_dict["spam"] = 0 - labels_dict["lang_mismatch"] = 0 + prompt = gen_random_text() + user_message_id = _random_message_id() + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_message", + "message_id": message_id, + "task_id": task["id"], + "user_message_id": user_message_id, + "text": prompt, + "user": USER, + }, + ) + tasks.append(new_task) - # send labels - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "text_labels", - "message_id": task["message_id"], - "task_id": task["id"], - "text": task["prompt"], - "labels": labels_dict, - "user": USER, - }, - ) - tasks.append(new_task) - case "prompter_reply": - # acknowledge task - message_id = _random_message_id() - user_message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - # send interaction - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "text_reply_to_message", - "message_id": message_id, - "task_id": task["id"], - "user_message_id": user_message_id, - "text": gen_random_text(), - "user": USER, - }, - ) - tasks.append(new_task) + case "label_initial_prompt": + typer.echo("Label the following prompt:") + typer.echo(task["prompt"]) + # acknowledge task + message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - case "assistant_reply": - # acknowledge task - message_id = _random_message_id() - user_message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - # send interaction - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "text_reply_to_message", - "message_id": message_id, - "task_id": task["id"], - "user_message_id": user_message_id, - "text": gen_random_text(), - "user": USER, - }, - ) - tasks.append(new_task) + valid_labels = task["valid_labels"] + mandatory_labels = task["mandatory_labels"] - case "rank_prompter_replies" | "rank_assistant_replies": - # acknowledge task - message_id = _random_message_id() - user_message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - # send interaction - ranking = gen_random_ranking(task["replies"]) - print(ranking) - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "message_ranking", - "message_id": message_id, - "task_id": task["id"], - "ranking": ranking, - "user": USER, - }, - ) - tasks.append(new_task) + labels_dict = None + if task["mode"] == "simple" and len(valid_labels) == 1: + answer = random.choice([True, False]) + labels_dict = {valid_labels[0]: 1 if answer else 0} + else: + labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) + for l in mandatory_labels: + if l not in labels: + labels.append(l) + labels_dict = {label: random.random() for label in valid_labels} + if random.random() < 0.9: + labels_dict["spam"] = 0 + labels_dict["lang_mismatch"] = 0 - case "rank_initial_prompts": - # acknowledge task - message_id = _random_message_id() - user_message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - # send interaction - ranking = gen_random_ranking(task["prompots"]) - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "message_ranking", - "message_id": message_id, - "ranking": ranking, - "user": USER, - }, - ) - tasks.append(new_task) + # send labels + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_labels", + "message_id": task["message_id"], + "task_id": task["id"], + "text": task["prompt"], + "labels": labels_dict, + "user": USER, + }, + ) + tasks.append(new_task) + case "prompter_reply": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_message", + "message_id": message_id, + "task_id": task["id"], + "user_message_id": user_message_id, + "text": gen_random_text(), + "user": USER, + }, + ) + tasks.append(new_task) - case "label_prompter_reply" | "label_assistant_reply": - # acknowledge task - typer.echo("Here is the conversation so far:") - for message in task["conversation"]["messages"]: - typer.echo(_render_message(message)) + case "assistant_reply": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_message", + "message_id": message_id, + "task_id": task["id"], + "user_message_id": user_message_id, + "text": gen_random_text(), + "user": USER, + }, + ) + tasks.append(new_task) - typer.echo("Label the following reply:") - typer.echo(task["reply"]) - message_id = _random_message_id() - user_message_id = _random_message_id() - _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) - valid_labels = task["valid_labels"] - mandatory_labels = task["mandatory_labels"] + case "rank_prompter_replies" | "rank_assistant_replies": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + ranking = gen_random_ranking(task["replies"]) + print(ranking) + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "message_ranking", + "message_id": message_id, + "task_id": task["id"], + "ranking": ranking, + "user": USER, + }, + ) + tasks.append(new_task) - labels_dict = None - if task["mode"] == "simple" and len(valid_labels) == 1: - answer = random.choice([True, False]) - labels_dict = {valid_labels[0]: 1 if answer else 0} - else: - labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) - for l in mandatory_labels: - if l not in labels: - labels.append(l) - labels_dict = {label: random.random() for label in valid_labels} - if random.random() < 0.9: - labels_dict["spam"] = 0 - labels_dict["lang_mismatch"] = 0 + case "rank_initial_prompts": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + ranking = gen_random_ranking(task["prompots"]) + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "message_ranking", + "message_id": message_id, + "ranking": ranking, + "user": USER, + }, + ) + tasks.append(new_task) - # send interaction - new_task = _post( - "/api/v1/tasks/interaction", - { - "type": "text_labels", - "message_id": task["message_id"], - "task_id": task["id"], - "text": task["reply"], - "labels": labels_dict, - "user": USER, - }, - ) - tasks.append(new_task) - case "task_done": - typer.echo("Task done!") - # rerun with new task selected from above cases - # add a new task - q += 1 - if q == 10: + case "label_prompter_reply" | "label_assistant_reply": + # acknowledge task + typer.echo("Here is the conversation so far:") + for message in task["conversation"]["messages"]: + typer.echo(_render_message(message)) + + typer.echo("Label the following reply:") + typer.echo(task["reply"]) + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + valid_labels = task["valid_labels"] + mandatory_labels = task["mandatory_labels"] + + labels_dict = None + if task["mode"] == "simple" and len(valid_labels) == 1: + answer = random.choice([True, False]) + labels_dict = {valid_labels[0]: 1 if answer else 0} + else: + labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) + for l in mandatory_labels: + if l not in labels: + labels.append(l) + labels_dict = {label: random.random() for label in valid_labels} + if random.random() < 0.9: + labels_dict["spam"] = 0 + labels_dict["lang_mismatch"] = 0 + + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_labels", + "message_id": task["message_id"], + "task_id": task["id"], + "text": task["reply"], + "labels": labels_dict, + "user": USER, + }, + ) + tasks.append(new_task) + case "task_done": typer.echo("Task done!") - break - tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] - # - case _: - typer.echo(f"Unknown task type {task['type']}") - # rerun with new task selected from above cases + # rerun with new task selected from above cases + # add a new task + q += 1 + if q == task_per_user: + typer.echo("Task done!") + break + tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] + # + case _: + typer.echo(f"Unknown task type {task['type']}") + # rerun with new task selected from above cases if __name__ == "__main__": diff --git a/text-frontend/requirements.txt b/text-frontend/requirements.txt index 3904ecb5..fe1ce30a 100644 --- a/text-frontend/requirements.txt +++ b/text-frontend/requirements.txt @@ -1,2 +1,3 @@ +faker==16.6.1 requests==2.28.1 typer==0.7.0 diff --git a/website/src/components/Dashboard/LeaderboardWidget.tsx b/website/src/components/Dashboard/LeaderboardWidget.tsx index 38fe6207..3e40e5ee 100644 --- a/website/src/components/Dashboard/LeaderboardWidget.tsx +++ b/website/src/components/Dashboard/LeaderboardWidget.tsx @@ -1,6 +1,6 @@ import { Card, CardBody, Link, Text } from "@chakra-ui/react"; -import { useTranslation } from "next-i18next"; import NextLink from "next/link"; +import { useTranslation } from "next-i18next"; import { LeaderboardTable } from "src/components/LeaderboardTable"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; @@ -19,7 +19,7 @@ export function LeaderboardWidget() { - + diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx index c724ada5..f25f30dd 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable.tsx @@ -23,7 +23,7 @@ import { Tr, useDisclosure, } from "@chakra-ui/react"; -import { ColumnDef, flexRender, getCoreRowModel, Row, useReactTable } from "@tanstack/react-table"; +import { Cell, ColumnDef, flexRender, getCoreRowModel, Row, useReactTable } from "@tanstack/react-table"; import { Filter } from "lucide-react"; import { useTranslation } from "next-i18next"; import { ChangeEvent, ReactNode } from "react"; @@ -31,6 +31,7 @@ import { useDebouncedCallback } from "use-debounce"; export type DataTableColumnDef = ColumnDef & { filterable?: boolean; + span?: number | ((cell: Cell) => number | undefined); }; // TODO: stricter type @@ -126,9 +127,7 @@ export const DataTable = ({ const props = typeof rowProps === "function" ? rowProps(row) : rowProps; return ( - {row.getVisibleCells().map((cell) => ( - {flexRender(cell.column.columnDef.cell, cell.getContext())} - ))} + ); })} @@ -139,6 +138,36 @@ export const DataTable = ({ ); }; +type WithSpanCell = Cell & { span?: number }; + +const DataTableRow = ({ row }: { row: Row }) => { + const cells: WithSpanCell[] = row.getVisibleCells(); + const renderCells: WithSpanCell[] = []; + + for (let i = 0; i < cells.length; i++) { + const cell = cells[i]; + const span = (cell.column.columnDef as DataTableColumnDef).span; + const spanValue = typeof span === "function" ? span(cell) : span; + if (spanValue && spanValue > 1) { + i += spanValue - 1; // skip next `spanValue - 1` cell + } + cell.span = spanValue; + renderCells.push(cell); + } + + return ( + <> + {renderCells.map((cell) => { + return ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ); + })} + + ); +}; + const FilterModal = ({ label, onChange, diff --git a/website/src/components/LeaderboardTable/LeaderboardTable.tsx b/website/src/components/LeaderboardTable/LeaderboardTable.tsx index 6314080c..b775be8d 100644 --- a/website/src/components/LeaderboardTable/LeaderboardTable.tsx +++ b/website/src/components/LeaderboardTable/LeaderboardTable.tsx @@ -1,5 +1,6 @@ -import { CircularProgress, useColorModeValue, useToken } from "@chakra-ui/react"; +import { Box, CircularProgress, Flex, useColorModeValue, useToken } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; +import { MoreHorizontal } from "lucide-react"; import { useTranslation } from "next-i18next"; import React, { useCallback, useMemo, useState } from "react"; import { get } from "src/lib/api"; @@ -7,9 +8,11 @@ import { colors } from "src/styles/Theme/colors"; import { LeaderboardEntity, LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import useSWRImmutable from "swr/immutable"; -import { DataTable, DataTableRowPropsCallback } from "../DataTable"; +import { DataTable, DataTableColumnDef, DataTableRowPropsCallback } from "../DataTable"; -const columnHelper = createColumnHelper(); +type WindowLeaderboardEntity = LeaderboardEntity & { isSpaceRow?: boolean }; + +const columnHelper = createColumnHelper(); /** * Presents a grid of leaderboard entries with more detailed information. @@ -18,10 +21,12 @@ export const LeaderboardTable = ({ timeFrame, limit: limit, rowPerPage, + hideCurrentUserRanking, }: { timeFrame: LeaderboardTimeFrame; limit: number; rowPerPage: number; + hideCurrentUserRanking?: boolean; }) => { const { t } = useTranslation("leaderboard"); @@ -29,15 +34,19 @@ export const LeaderboardTable = ({ data: reply, isLoading, error, - } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}&limit=${limit}`, get, { - revalidateOnMount: true, - }); - - const columns = useMemo( + } = useSWRImmutable( + `/api/leaderboard?time_frame=${timeFrame}&limit=${limit}&includeUserStats=${!hideCurrentUserRanking}`, + get + ); + const columns: DataTableColumnDef[] = useMemo( () => [ - columnHelper.accessor("rank", { - header: t("rank"), - }), + { + ...columnHelper.accessor("rank", { + header: t("rank"), + cell: ({ row, getValue }) => (row.original.isSpaceRow ? : getValue()), + }), + span: (cell) => (cell.row.original.isSpaceRow ? 6 : undefined), + }, columnHelper.accessor("display_name", { header: t("user"), }), @@ -63,15 +72,72 @@ export const LeaderboardTable = ({ }, [t, reply?.last_updated]); const [page, setPage] = useState(1); - const data = useMemo(() => { + const data: WindowLeaderboardEntity[] = useMemo(() => { + if (!reply) { + return []; + } const start = (page - 1) * rowPerPage; - return reply?.leaderboard.slice(start, start + rowPerPage) || []; - }, [rowPerPage, page, reply?.leaderboard]); + const end = start + rowPerPage; + const leaderBoardEntities = reply.leaderboard.slice(start, end); + if (hideCurrentUserRanking) { + return leaderBoardEntities; + } + const userStatsWindow: WindowLeaderboardEntity[] = reply.user_stats_window; + const userStats = userStatsWindow.find((stats) => stats.highlighted); + if (userStats.rank > end) { + leaderBoardEntities.push( + { isSpaceRow: true } as WindowLeaderboardEntity, + ...reply.user_stats_window.filter( + (stats) => + leaderBoardEntities.findIndex((leaderBoardEntity) => leaderBoardEntity.user_id === stats.user_id) === -1 + ) // filter to avoid duplicated row + ); + } + return leaderBoardEntities; + }, [page, rowPerPage, reply, hideCurrentUserRanking]); + const rowProps = useLeaderboardRowProps(); + + if (isLoading) { + return ; + } + + if (error) { + return Unable to load leaderboard; + } + + const maxPage = Math.ceil(reply.leaderboard.length / rowPerPage); + + return ( + = maxPage} + disablePrevious={page === 1} + onNextClick={() => setPage((p) => p + 1)} + onPreviousClick={() => setPage((p) => p - 1)} + rowProps={rowProps} + > + ); +}; + +const SpaceRow = () => { + const color = useColorModeValue("gray.600", "gray.400"); + return ( + + + + ); +}; + +const useLeaderboardRowProps = () => { const borderColor = useToken("colors", useColorModeValue(colors.light.active, colors.dark.active)); - const rowProps = useCallback>( + return useCallback>( (row) => { - return row.original.highlighted + const rowData = row.original; + return rowData.highlighted ? { sx: { // https://stackoverflow.com/questions/37963524/how-to-apply-border-radius-to-tr-in-bootstrap @@ -93,28 +159,4 @@ export const LeaderboardTable = ({ }, [borderColor] ); - - if (isLoading) { - return ; - } - - if (error) { - return Unable to load leaderboard; - } - - const maxPage = Math.ceil(reply.leaderboard.length / rowPerPage); - - return ( - setPage((p) => p + 1)} - onPreviousClick={() => setPage((p) => p - 1)} - rowProps={rowProps} - > - ); }; diff --git a/website/src/lib/api.ts b/website/src/lib/api.ts index 27b1b811..c35ea87f 100644 --- a/website/src/lib/api.ts +++ b/website/src/lib/api.ts @@ -25,7 +25,13 @@ api.interceptors.response.use( (response) => response, (error) => { const err = error?.response?.data; - throw new OasstError(err?.message ?? error, err?.errorCode, error?.response?.httpStatusCode || -1); + throw new OasstError({ + message: err?.message ?? error, + errorCode: err?.errorCode, + httpStatusCode: error?.response?.httpStatusCode || -1, + method: err?.config?.method, + path: err?.config?.url, + }); } ); diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 2aafab37..abd1173f 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -7,11 +7,27 @@ export class OasstError { message: string; errorCode: number; httpStatusCode: number; + path: string; + method: string; - constructor(message: string, errorCode: number, httpStatusCode: number) { + constructor({ + errorCode, + httpStatusCode, + message, + path, + method, + }: { + message: string; + errorCode: number; + httpStatusCode: number; + path: string; + method: string; + }) { this.message = message; this.errorCode = errorCode; this.httpStatusCode = httpStatusCode; + this.path = path; + this.method = method; } toString() { @@ -60,9 +76,21 @@ export class OasstApiClient { try { error = JSON.parse(errorText); } catch (e) { - throw new OasstError(errorText, 0, resp.status); + throw new OasstError({ + message: errorText, + errorCode: 0, + httpStatusCode: resp.status, + path, + method, + }); } - throw new OasstError(error.message ?? error, error.error_code, resp.status); + throw new OasstError({ + message: error.message ?? error, + errorCode: error.error_code, + httpStatusCode: resp.status, + path, + method, + }); } return resp.json(); @@ -297,9 +325,9 @@ export class OasstApiClient { return this.get(`/api/v1/messages/${messageId}/conversation`); } - async fetch_tos_acceptance(user: BackendUserCore): Promise { - const backendUser = await this.get(`/api/v1/frontend_users/${user.auth_method}/${user.id}`); - return backendUser.tos_acceptance_date; + async fetch_tos_acceptance(backendUserCore: BackendUserCore): Promise { + const user = await this.fetch_frontend_user(backendUserCore); + return user.tos_acceptance_date; } async set_tos_acceptance(user: BackendUserCore) { @@ -312,4 +340,14 @@ export class OasstApiClient { const backendUser = await this.get(`/api/v1/frontend_users/${user.auth_method}/${user.id}`); return this.get(`/api/v1/users/${backendUser.user_id}/stats`); } + + fetch_user_stats_window(user_id: string, time_frame: LeaderboardTimeFrame, window_size?: number) { + return this.get(`/api/v1/users/${user_id}/stats/${time_frame}/window`, { + window_size, + }); + } + + fetch_frontend_user(user: BackendUserCore) { + return this.get(`/api/v1/frontend_users/${user.auth_method}/${user.id}`); + } } diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index fad1d8a6..42f74c3d 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -1,5 +1,6 @@ import { withoutRole } from "src/lib/auth"; import { createApiClient } from "src/lib/oasst_client_factory"; +import { getBackendUserCore } from "src/lib/users"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; /** @@ -7,9 +8,29 @@ import { LeaderboardTimeFrame } from "src/types/Leaderboard"; */ const handler = withoutRole("banned", async (req, res, token) => { const oasstApiClient = await createApiClient(token); + const backendUser = await getBackendUserCore(token.sub); const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day; - const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number }); - res.status(200).json(info); + const includeUserStats = req.query.includeUserStats; + + if (includeUserStats !== "true") { + const leaderboard = await oasstApiClient.fetch_leaderboard(time_frame, { + limit: req.query.limit as unknown as number, + }); + return res.status(200).json(leaderboard); + } + const user = await oasstApiClient.fetch_frontend_user(backendUser); + + const [leaderboard, user_stats] = await Promise.all([ + oasstApiClient.fetch_leaderboard(time_frame, { + limit: req.query.limit as unknown as number, + }), + oasstApiClient.fetch_user_stats_window(user.user_id, time_frame, 3), + ]); + + res.status(200).json({ + ...leaderboard, + user_stats_window: user_stats.leaderboard.map((stats) => ({ ...stats, is_window: true })), + }); }); export default handler;