From cdb9f2da4e89a562531d5a89576f18263de1a970 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 18:18:17 +0100 Subject: [PATCH 1/8] Send `x-oasst-user` header --- website/src/lib/auth.ts | 4 ++-- website/src/lib/oasst_api_client.ts | 24 ++++++++++++++----- website/src/pages/api/admin/status.ts | 6 ++--- website/src/pages/api/admin/update_user.ts | 13 ++++------ website/src/pages/api/admin/user_messages.ts | 5 ++-- website/src/pages/api/admin/users.ts | 5 ++-- website/src/pages/api/available_tasks.ts | 3 ++- website/src/pages/api/leaderboard.ts | 5 ++-- website/src/pages/api/messages/index.ts | 2 ++ website/src/pages/api/messages/user.ts | 3 ++- website/src/pages/api/new_task/[task_type].ts | 3 ++- website/src/pages/api/reject_task.ts | 14 ++++++----- website/src/pages/api/set_label.ts | 1 + website/src/pages/api/update_task.ts | 19 +++++++++------ website/src/pages/api/valid_labels.ts | 7 +++--- 15 files changed, 70 insertions(+), 44 deletions(-) diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts index 42c6cf79..ee004ba9 100644 --- a/website/src/lib/auth.ts +++ b/website/src/lib/auth.ts @@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe * Wraps any API Route handler and verifies that the user has the appropriate * role before running the handler. Returns a 403 otherwise. */ -const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => { +const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => { return async (req: NextApiRequest, res: NextApiResponse) => { const token = await getToken({ req }); if (!token || token.role !== role) { res.status(403).end(); return; } - return handler(req, res); + return handler(req, res, token); }; }; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 141e325c..f745f4cb 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,8 +1,11 @@ +import { JWT } from "next-auth/jwt"; import type { EmojiOp, Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import type { AvailableTasks } from "src/types/Task"; import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users"; +import { getBackendUserCore } from "./users"; + export class OasstError { message: string; errorCode: number; @@ -18,10 +21,16 @@ export class OasstError { export class OasstApiClient { oasstApiUrl: string; oasstApiKey: string; + userHeaders: Record = {}; - constructor(oasstApiUrl: string, oasstApiKey: string) { + constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) { this.oasstApiUrl = oasstApiUrl; this.oasstApiKey = oasstApiKey; + if (user) { + this.userHeaders = { + "X-OASST-USER": `${user.auth_method}:${user.id}`, + }; + } } // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. @@ -215,9 +224,10 @@ export class OasstApiClient { method, ...init, headers: { + ...init?.headers, + ...this.userHeaders, "X-API-Key": this.oasstApiKey, "Content-Type": "application/json", - ...init?.headers, }, }); @@ -227,8 +237,7 @@ export class OasstApiClient { if (resp.status >= 300) { const errorText = await resp.text(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let error: any; + let error; try { error = JSON.parse(errorText); } catch (e) { @@ -241,6 +250,9 @@ export class OasstApiClient { } } -const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); +export const createApiClientFromUser = (user: BackendUserCore) => + new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user); -export { oasstApiClient }; +export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub)); + +export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/pages/api/admin/status.ts b/website/src/pages/api/admin/status.ts index 1da03da8..b496b813 100644 --- a/website/src/pages/api/admin/status.ts +++ b/website/src/pages/api/admin/status.ts @@ -1,17 +1,17 @@ -import { getToken } from "next-auth/jwt"; import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; -import { getBackendUserCore } from "src/lib/users"; +import { createApiClientFromUser } from "src/lib/oasst_api_client"; /** * Returns tasks availability, stats, and tree manager stats. */ const handler = withRole("admin", async (req, res) => { + // NOTE: why are we using a dummy user here? const dummyUser = { id: "__dummy_user__", display_name: "Dummy User", auth_method: "local", }; + const oasstApiClient = createApiClientFromUser(dummyUser); const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([ oasstApiClient.fetch_tasks_availability(dummyUser), oasstApiClient.fetch_stats(), diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts index 341ec736..ca44b7b5 100644 --- a/website/src/pages/api/admin/update_user.ts +++ b/website/src/pages/api/admin/update_user.ts @@ -1,22 +1,19 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; /** * Update's the user's data in the database. Accessible only to admins. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { id, auth_method, user_id, notes, role } = req.body; + const oasstApiClient = await createApiClient(token); // If the user is authorized by the web, update their role. if (auth_method === "local") { await prisma.user.update({ - where: { - id, - }, - data: { - role, - }, + where: { id }, + data: { role }, }); } // Tell the backend the user's enabled or not enabled status. diff --git a/website/src/pages/api/admin/user_messages.ts b/website/src/pages/api/admin/user_messages.ts index 236afa2d..fdde2e5f 100644 --- a/website/src/pages/api/admin/user_messages.ts +++ b/website/src/pages/api/admin/user_messages.ts @@ -1,12 +1,13 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import type { Message } from "src/types/Conversation"; /** * Returns the messages recorded by the backend for a user. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { user } = req.query; + const oasstApiClient = await createApiClient(token); const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index d10c91b0..e6f05d61 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { FetchUsersParams } from "src/types/Users"; @@ -17,9 +17,10 @@ const PAGE_SIZE = 20; * - `direction`: Either "forward" or "backward" representing the pagination * direction. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query; + const oasstApiClient = await createApiClient(token); // First, get all the users according to the backend. const { items: all_users, ...rest } = await oasstApiClient.fetch_users({ searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"], diff --git a/website/src/pages/api/available_tasks.ts b/website/src/pages/api/available_tasks.ts index 79dcc8b9..74630dd6 100644 --- a/website/src/pages/api/available_tasks.ts +++ b/website/src/pages/api/available_tasks.ts @@ -1,9 +1,10 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_api_client"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); const userLanguage = getUserLanguage(req); const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage); res.status(200).json(availableTasks); diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index be91e7b4..fc07a3b2 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -1,11 +1,12 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; /** * Returns the set of valid labels that can be applied to messages. */ -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { + const oasstApiClient = await createApiClient(token); const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day; const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number }); res.status(200).json(info); diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts index cb9728fe..fdd01313 100644 --- a/website/src/pages/api/messages/index.ts +++ b/website/src/pages/api/messages/index.ts @@ -1,6 +1,8 @@ import { withoutRole } from "src/lib/auth"; const handler = withoutRole("banned", async (req, res) => { + // TODO: move to oasst_api_client + const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, { method: "GET", headers: { diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index 6f39aad1..3f2e739f 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -2,13 +2,14 @@ import { withoutRole } from "src/lib/auth"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { - //TODO: add params if needed const user = await getBackendUserCore(token.sub); const params = new URLSearchParams({ username: user.id, auth_method: user.auth_method, }); + // TODO: move to oasst_api_client + const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { method: "GET", headers: { diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 360b8faa..cd964312 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; @@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const userLanguage = getUserLanguage(req); const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); let task; try { task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage); diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts index 0084fffa..190b94ba 100644 --- a/website/src/pages/api/reject_task.ts +++ b/website/src/pages/api/reject_task.ts @@ -1,19 +1,21 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { // Parse out the local task ID and the interaction contents. const { id: frontendId, reason } = req.body; - const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }); + const [oasstApiClient, registeredTask] = await Promise.all([ + createApiClient(token), + prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }), + ]); - const task = registeredTask.task as Prisma.JsonObject; - const id = task.id as string; + const taskId = (registeredTask.task as Prisma.JsonObject).id as string; // Update the backend with the rejection - await oasstApiClient.nackTask(id, reason); + await oasstApiClient.nackTask(taskId, reason); // Send the results to the client. res.status(200).json({}); diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts index f72b2808..3c54a89f 100644 --- a/website/src/pages/api/set_label.ts +++ b/website/src/pages/api/set_label.ts @@ -5,6 +5,7 @@ import { withoutRole } from "src/lib/auth"; * */ const handler = withoutRole("banned", async (req, res, token) => { + // TODO: move to oasst_api_client // Parse out the local message_id, and the interaction contents. const { message_id, label_map } = req.body; diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 6f08d640..524e6b7a 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -1,6 +1,6 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; @@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => { // Parse out the local task ID and the interaction contents. const { id: frontendId, content, update_type } = req.body; - // Record that the user has done meaningful work and is no longer new. - await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }); + // do in parallel since they are independent + const [_, registeredTask, oasstApiClient] = await Promise.all([ + // Record that the user has done meaningful work and is no longer new. + prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }), + // Accept the task so that we can complete it, this will probably go away soon. + prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }), + // Create client for upcoming requests + createApiClient(token), + ]); + + const taskId = (registeredTask.task as Prisma.JsonObject).id as string; - // Accept the task so that we can complete it, this will probably go away soon. - const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }); - const task = registeredTask.task as Prisma.JsonObject; - const taskId = task.id as string; await oasstApiClient.ackTask(taskId, registeredTask.id); // Log the interaction locally to create our user_post_id needed by the Task diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts index 5f61ff2f..0fa54e5b 100644 --- a/website/src/pages/api/valid_labels.ts +++ b/website/src/pages/api/valid_labels.ts @@ -1,11 +1,12 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; /** * Returns the set of valid labels that can be applied to messages. */ -const handler = withoutRole("banned", async (req, res) => { - const valid_labels = await oasstApiClient.fetch_valid_text(); +const handler = withoutRole("banned", async (req, res, token) => { + const client = await createApiClient(token); + const valid_labels = await client.fetch_valid_text(); res.status(200).json(valid_labels); }); From 59531c7d3d596d69b5778a18ea21e2ded42b0797 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 21:03:27 +0100 Subject: [PATCH 2/8] userless client --- website/src/pages/admin/manage_user/[id].tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index b53bb7c0..4a042c52 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -10,7 +10,7 @@ import { getAdminLayout } from "src/components/Layout"; import { Role, RoleSelect } from "src/components/RoleSelect"; import { UserMessagesCell } from "src/components/UserMessagesCell"; import { post } from "src/lib/api"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { userlessApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import useSWRMutation from "swr/mutation"; @@ -113,7 +113,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType Date: Fri, 27 Jan 2023 21:18:20 +0100 Subject: [PATCH 3/8] Move to factory --- website/src/lib/oasst_api_client.ts | 10 ---------- website/src/lib/oasst_client_factory.ts | 11 +++++++++++ website/src/lib/prismadb.ts | 2 +- website/src/pages/api/admin/status.ts | 2 +- website/src/pages/api/admin/update_user.ts | 2 +- website/src/pages/api/admin/user_messages.ts | 2 +- website/src/pages/api/admin/users.ts | 2 +- website/src/pages/api/available_tasks.ts | 2 +- website/src/pages/api/leaderboard.ts | 2 +- website/src/pages/api/new_task/[task_type].ts | 2 +- website/src/pages/api/reject_task.ts | 2 +- website/src/pages/api/report.ts | 3 ++- website/src/pages/api/update_task.ts | 2 +- website/src/pages/api/valid_labels.ts | 2 +- 14 files changed, 24 insertions(+), 22 deletions(-) create mode 100644 website/src/lib/oasst_client_factory.ts diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index f745f4cb..d7ca8281 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,11 +1,8 @@ -import { JWT } from "next-auth/jwt"; import type { EmojiOp, Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import type { AvailableTasks } from "src/types/Task"; import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users"; -import { getBackendUserCore } from "./users"; - export class OasstError { message: string; errorCode: number; @@ -249,10 +246,3 @@ export class OasstApiClient { return await resp.json(); } } - -export const createApiClientFromUser = (user: BackendUserCore) => - new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user); - -export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub)); - -export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/lib/oasst_client_factory.ts b/website/src/lib/oasst_client_factory.ts new file mode 100644 index 00000000..9f9bd657 --- /dev/null +++ b/website/src/lib/oasst_client_factory.ts @@ -0,0 +1,11 @@ +import { JWT } from "next-auth/jwt"; +import { OasstApiClient } from "src/lib/oasst_api_client"; +import { getBackendUserCore } from "src/lib/users"; +import { BackendUserCore } from "src/types/Users"; + +export const createApiClientFromUser = (user: BackendUserCore) => + new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user); + +export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub)); + +export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/lib/prismadb.ts b/website/src/lib/prismadb.ts index 296eda8b..336d782e 100644 --- a/website/src/lib/prismadb.ts +++ b/website/src/lib/prismadb.ts @@ -3,7 +3,7 @@ declare global { // eslint-disable-next-line no-var var prisma: PrismaClient | undefined; } - +console.trace() const client = globalThis.prisma || new PrismaClient(); if (process.env.NODE_ENV !== "production") { globalThis.prisma = client; diff --git a/website/src/pages/api/admin/status.ts b/website/src/pages/api/admin/status.ts index b496b813..956ad2cb 100644 --- a/website/src/pages/api/admin/status.ts +++ b/website/src/pages/api/admin/status.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { createApiClientFromUser } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; /** * Returns tasks availability, stats, and tree manager stats. diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts index ca44b7b5..c71159ad 100644 --- a/website/src/pages/api/admin/update_user.ts +++ b/website/src/pages/api/admin/update_user.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; /** diff --git a/website/src/pages/api/admin/user_messages.ts b/website/src/pages/api/admin/user_messages.ts index fdde2e5f..0223e8e3 100644 --- a/website/src/pages/api/admin/user_messages.ts +++ b/website/src/pages/api/admin/user_messages.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import type { Message } from "src/types/Conversation"; /** diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index e6f05d61..eae0f072 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import { FetchUsersParams } from "src/types/Users"; diff --git a/website/src/pages/api/available_tasks.ts b/website/src/pages/api/available_tasks.ts index 74630dd6..218e3864 100644 --- a/website/src/pages/api/available_tasks.ts +++ b/website/src/pages/api/available_tasks.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { createApiClientFromUser } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index fc07a3b2..fad1d8a6 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; /** diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index cd964312..34993f4f 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { createApiClientFromUser } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts index 190b94ba..2e3f4fa1 100644 --- a/website/src/pages/api/reject_task.ts +++ b/website/src/pages/api/reject_task.ts @@ -1,6 +1,6 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; const handler = withoutRole("banned", async (req, res, token) => { diff --git a/website/src/pages/api/report.ts b/website/src/pages/api/report.ts index 36252cad..9b904df4 100644 --- a/website/src/pages/api/report.ts +++ b/website/src/pages/api/report.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; /** @@ -11,6 +11,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const { message_id, text } = req.body; const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); try { await oasstApiClient.send_report(message_id, user, text); } catch (err) { diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 524e6b7a..1b4f2eda 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -1,6 +1,6 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts index 0fa54e5b..dca92d90 100644 --- a/website/src/pages/api/valid_labels.ts +++ b/website/src/pages/api/valid_labels.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { createApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; /** * Returns the set of valid labels that can be applied to messages. From 3c791efb7996496a446aee66694221be1cb53449 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 21:59:29 +0100 Subject: [PATCH 4/8] Move api calls to `oasst_api_client` --- website/src/lib/oasst_api_client.ts | 20 +++++++++++++ website/src/lib/prismadb.ts | 2 +- website/src/pages/admin/manage_user/[id].tsx | 2 +- .../src/pages/api/messages/[id]/children.ts | 16 +++-------- .../pages/api/messages/[id]/conversation.ts | 16 +++-------- website/src/pages/api/messages/[id]/emoji.ts | 3 +- website/src/pages/api/messages/[id]/index.ts | 19 ++----------- website/src/pages/api/messages/[id]/parent.ts | 28 +++++-------------- website/src/pages/api/messages/index.ts | 16 +++-------- website/src/pages/api/messages/user.ts | 19 ++----------- website/src/types/Conversation.ts | 5 ++-- 11 files changed, 52 insertions(+), 94 deletions(-) diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index d7ca8281..b9a9489e 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -245,4 +245,24 @@ export class OasstApiClient { return await resp.json(); } + + fetch_my_messages(user: BackendUserCore) { + const params = new URLSearchParams({ + username: user.id, + auth_method: user.auth_method, + }); + return this.get(`/api/v1/messages?${params}`); + } + + fetch_recent_messages() { + return this.get(`/api/v1/messages`); + } + + fetch_message_children(messageId: string) { + return this.get(`/api/v1/messages/${messageId}/children`); + } + + fetch_conversation(messageId: string) { + return this.get(`/api/v1/messages/${messageId}/conversation`); + } } diff --git a/website/src/lib/prismadb.ts b/website/src/lib/prismadb.ts index 336d782e..296eda8b 100644 --- a/website/src/lib/prismadb.ts +++ b/website/src/lib/prismadb.ts @@ -3,7 +3,7 @@ declare global { // eslint-disable-next-line no-var var prisma: PrismaClient | undefined; } -console.trace() + const client = globalThis.prisma || new PrismaClient(); if (process.env.NODE_ENV !== "production") { globalThis.prisma = client; diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index 4a042c52..a68bca16 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -10,7 +10,7 @@ import { getAdminLayout } from "src/components/Layout"; import { Role, RoleSelect } from "src/components/RoleSelect"; import { UserMessagesCell } from "src/components/UserMessagesCell"; import { post } from "src/lib/api"; -import { userlessApiClient } from "src/lib/oasst_api_client"; +import { userlessApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import useSWRMutation from "swr/mutation"; diff --git a/website/src/pages/api/messages/[id]/children.ts b/website/src/pages/api/messages/[id]/children.ts index 4a184c11..0185b615 100644 --- a/website/src/pages/api/messages/[id]/children.ts +++ b/website/src/pages/api/messages/[id]/children.ts @@ -1,18 +1,10 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient } from "src/lib/oasst_client_factory"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. + const client = await createApiClient(token); + const messages = await client.fetch_message_children(id as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/messages/[id]/conversation.ts b/website/src/pages/api/messages/[id]/conversation.ts index 0c401883..fc284e40 100644 --- a/website/src/pages/api/messages/[id]/conversation.ts +++ b/website/src/pages/api/messages/[id]/conversation.ts @@ -1,18 +1,10 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient } from "src/lib/oasst_client_factory"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. + const client = await createApiClient(token); + const messages = await client.fetch_conversation(id as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/messages/[id]/emoji.ts b/website/src/pages/api/messages/[id]/emoji.ts index 59d93f69..8d6257c1 100644 --- a/website/src/pages/api/messages/[id]/emoji.ts +++ b/website/src/pages/api/messages/[id]/emoji.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { @@ -15,6 +15,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const { emoji, op } = req.body; const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); try { await oasstApiClient.set_user_message_emoji(messageId, user, emoji, op); } catch (err) { diff --git a/website/src/pages/api/messages/[id]/index.ts b/website/src/pages/api/messages/[id]/index.ts index 28947feb..b3361a2d 100644 --- a/website/src/pages/api/messages/[id]/index.ts +++ b/website/src/pages/api/messages/[id]/index.ts @@ -1,25 +1,12 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; - const user = await getBackendUserCore(token.sub); - const params = new URLSearchParams({ - username: user.id, - auth_method: user.auth_method, - }); - - const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/?${params}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - const message = await messageRes.json(); - - // Send recieved messages to the client. + const client = createApiClientFromUser(user); + const message = await client.fetch_message(id as string, user); res.status(200).json(message); }); diff --git a/website/src/pages/api/messages/[id]/parent.ts b/website/src/pages/api/messages/[id]/parent.ts index ba332be9..9b8c64d4 100644 --- a/website/src/pages/api/messages/[id]/parent.ts +++ b/website/src/pages/api/messages/[id]/parent.ts @@ -1,6 +1,8 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient, createApiClientFromUser } from "src/lib/oasst_client_factory"; +import { getBackendUserCore } from "src/lib/users"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; if (!id) { @@ -8,32 +10,16 @@ const handler = withoutRole("banned", async (req, res) => { return; } - const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - - const message = await messageRes.json(); + const user = await getBackendUserCore(token.sub); + const client = createApiClientFromUser(user); + const message = await client.fetch_message(id as string, user); if (!message.parent_id) { res.status(404).end(); return; } - const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - - const parent = await parentRes.json(); - - // Send recieved messages to the client. + const parent = await client.fetch_message(message.parent_id, user); res.status(200).json(parent); }); diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts index fdd01313..fbcaee3c 100644 --- a/website/src/pages/api/messages/index.ts +++ b/website/src/pages/api/messages/index.ts @@ -1,17 +1,9 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient } from "src/lib/oasst_client_factory"; -const handler = withoutRole("banned", async (req, res) => { - // TODO: move to oasst_api_client - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. +const handler = withoutRole("banned", async (req, res, token) => { + const client = await createApiClient(token); + const messages = await client.fetch_recent_messages(); res.status(200).json(messages); }); diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index 3f2e739f..bffe2fb6 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -1,24 +1,11 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const user = await getBackendUserCore(token.sub); - const params = new URLSearchParams({ - username: user.id, - auth_method: user.auth_method, - }); - - // TODO: move to oasst_api_client - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. + const client = createApiClientFromUser(user); + const messages = await client.fetch_my_messages(user); res.status(200).json(messages); }); diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts index 9cc3a140..8b258a25 100644 --- a/website/src/types/Conversation.ts +++ b/website/src/types/Conversation.ts @@ -11,11 +11,12 @@ export interface MessageEmojis { } export interface Message extends MessageEmojis { + id: string; text: string; is_assistant: boolean; - id: string; - created_date: string; // iso date string lang: string; + created_date: string; // iso date string + parent_id: string; frontend_message_id?: string; } From 0a6d9011eaae4fbdb960eedbf28632a80d176f88 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 22:29:14 +0100 Subject: [PATCH 5/8] Fix semantic html nesting --- website/src/components/Survey/LabelInputGroup.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/src/components/Survey/LabelInputGroup.tsx b/website/src/components/Survey/LabelInputGroup.tsx index 9a06b6ca..94fcc48e 100644 --- a/website/src/components/Survey/LabelInputGroup.tsx +++ b/website/src/components/Survey/LabelInputGroup.tsx @@ -211,7 +211,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label }} alignItems="center" > - + {textA} {descriptionA.length > 0 ? : null} @@ -229,7 +229,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label /> - + {textB} {descriptionB.length > 0 ? : null} From ab227c5db5e5331c3ea291a29cee500b83a104b6 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 22:30:11 +0100 Subject: [PATCH 6/8] Fix error in labeling tasks --- .../components/Messages/MessageTableEntry.tsx | 10 +++--- .../components/Tasks/LabelTask/LabelTask.tsx | 34 ++++++++++++------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 3cde48f6..7202903a 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -34,7 +34,7 @@ interface MessageTableEntryProps { export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) { const router = useRouter(); - const [emojis, setEmojis] = useState({ emojis: {}, user_emojis: [] }); + const [emojiState, setEmojis] = useState({ emojis: {}, user_emojis: [] }); useEffect(() => { setEmojis({ emojis: message.emojis, user_emojis: message.user_emojis }); }, [message.emojis, message.user_emojis]); @@ -93,17 +93,17 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }} onClick={(e) => e.stopPropagation()} > - {Object.entries(emojis.emojis).map(([emoji, count]) => ( + {Object.entries(emojiState.emojis).map(([emoji, count]) => ( react(emoji, !emojis.user_emojis.includes(emoji))} + checked={emojiState.user_emojis.includes(emoji)} + onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))} /> ))} ; message_id: string }>) => { + const { i18n } = useTranslation(); const [sliderValues, setSliderValues] = useState(new Array(task.valid_labels.length).fill(null)); useEffect(() => { @@ -27,6 +30,23 @@ export const LabelTask = ({ const cardColor = useColorModeValue("gray.50", "gray.800"); const isSpamTask = task.mode === "simple" && task.valid_labels.length === 1 && task.valid_labels[0] === "spam"; + // TODO: remove as soon as the backend delivers + // real information about the current message + const additionMessage: Message = useMemo( + () => ({ + text: task.reply, + is_assistant: task.type === TaskType.label_assistant_reply, + message_id: task.message_id, + created_date: new Date().toISOString(), + emojis: {}, + user_emojis: [], + id: "dummy", + lang: i18n.language, + parent_id: "dummy", + }), + [task.reply, task.type, task.message_id, i18n.language] + ); + return (
@@ -34,17 +54,7 @@ export const LabelTask = ({ {task.conversation ? ( - + ) : ( From e4dcfe41614f22271aadb8e18fc15b7533896be5 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 22:41:20 +0100 Subject: [PATCH 7/8] Use new reply_message --- .../components/Tasks/LabelTask/LabelTask.tsx | 22 ++++--------------- website/src/types/Tasks.ts | 4 +++- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/website/src/components/Tasks/LabelTask/LabelTask.tsx b/website/src/components/Tasks/LabelTask/LabelTask.tsx index 2a066c87..97a38057 100644 --- a/website/src/components/Tasks/LabelTask/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask/LabelTask.tsx @@ -30,23 +30,6 @@ export const LabelTask = ({ const cardColor = useColorModeValue("gray.50", "gray.800"); const isSpamTask = task.mode === "simple" && task.valid_labels.length === 1 && task.valid_labels[0] === "spam"; - // TODO: remove as soon as the backend delivers - // real information about the current message - const additionMessage: Message = useMemo( - () => ({ - text: task.reply, - is_assistant: task.type === TaskType.label_assistant_reply, - message_id: task.message_id, - created_date: new Date().toISOString(), - emojis: {}, - user_emojis: [], - id: "dummy", - lang: i18n.language, - parent_id: "dummy", - }), - [task.reply, task.type, task.message_id, i18n.language] - ); - return (
@@ -54,7 +37,10 @@ export const LabelTask = ({ {task.conversation ? ( - + ) : ( diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts index a791916e..bbbe3a67 100644 --- a/website/src/types/Tasks.ts +++ b/website/src/types/Tasks.ts @@ -1,4 +1,4 @@ -import { Conversation } from "./Conversation"; +import { Conversation, Message } from "./Conversation"; import { BaseTask, TaskType } from "./Task"; export interface CreateInitialPromptTask extends BaseTask { @@ -37,6 +37,7 @@ export interface LabelAssistantReplyTask extends BaseTask { type: TaskType.label_assistant_reply; message_id: string; conversation: Conversation; + reply_message: Message; reply: string; valid_labels: string[]; mode: "simple" | "full"; @@ -47,6 +48,7 @@ export interface LabelPrompterReplyTask extends BaseTask { type: TaskType.label_prompter_reply; message_id: string; conversation: Conversation; + reply_message: Message; reply: string; valid_labels: string[]; mode: "simple" | "full"; From 4067bec03f1ebd3d52d90675d0c438f13fa88674 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 23:12:44 +0100 Subject: [PATCH 8/8] Remove unused i18n --- website/src/components/Tasks/LabelTask/LabelTask.tsx | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/website/src/components/Tasks/LabelTask/LabelTask.tsx b/website/src/components/Tasks/LabelTask/LabelTask.tsx index 97a38057..10ea76fb 100644 --- a/website/src/components/Tasks/LabelTask/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask/LabelTask.tsx @@ -1,14 +1,11 @@ import { Box, Button, Flex, HStack, Text, useColorModeValue } from "@chakra-ui/react"; -import { useEffect, useMemo, useState } from "react"; -import { useTranslation } from "react-i18next"; +import { useEffect, useState } from "react"; import { MessageView } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; import { TaskSurveyProps } from "src/components/Tasks/Task"; import { TaskHeader } from "src/components/Tasks/TaskHeader"; -import { Message } from "src/types/Conversation"; -import { TaskType } from "src/types/Task"; export const LabelTask = ({ task, @@ -17,7 +14,6 @@ export const LabelTask = ({ onReplyChanged, onValidityChanged, }: TaskSurveyProps<{ text: string; labels: Record; message_id: string }>) => { - const { i18n } = useTranslation(); const [sliderValues, setSliderValues] = useState(new Array(task.valid_labels.length).fill(null)); useEffect(() => {