From cdb9f2da4e89a562531d5a89576f18263de1a970 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 27 Jan 2023 18:18:17 +0100 Subject: [PATCH] Send `x-oasst-user` header --- website/src/lib/auth.ts | 4 ++-- website/src/lib/oasst_api_client.ts | 24 ++++++++++++++----- website/src/pages/api/admin/status.ts | 6 ++--- website/src/pages/api/admin/update_user.ts | 13 ++++------ website/src/pages/api/admin/user_messages.ts | 5 ++-- website/src/pages/api/admin/users.ts | 5 ++-- website/src/pages/api/available_tasks.ts | 3 ++- website/src/pages/api/leaderboard.ts | 5 ++-- website/src/pages/api/messages/index.ts | 2 ++ website/src/pages/api/messages/user.ts | 3 ++- website/src/pages/api/new_task/[task_type].ts | 3 ++- website/src/pages/api/reject_task.ts | 14 ++++++----- website/src/pages/api/set_label.ts | 1 + website/src/pages/api/update_task.ts | 19 +++++++++------ website/src/pages/api/valid_labels.ts | 7 +++--- 15 files changed, 70 insertions(+), 44 deletions(-) diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts index 42c6cf79..ee004ba9 100644 --- a/website/src/lib/auth.ts +++ b/website/src/lib/auth.ts @@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe * Wraps any API Route handler and verifies that the user has the appropriate * role before running the handler. Returns a 403 otherwise. */ -const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => { +const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => { return async (req: NextApiRequest, res: NextApiResponse) => { const token = await getToken({ req }); if (!token || token.role !== role) { res.status(403).end(); return; } - return handler(req, res); + return handler(req, res, token); }; }; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 141e325c..f745f4cb 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,8 +1,11 @@ +import { JWT } from "next-auth/jwt"; import type { EmojiOp, Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import type { AvailableTasks } from "src/types/Task"; import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users"; +import { getBackendUserCore } from "./users"; + export class OasstError { message: string; errorCode: number; @@ -18,10 +21,16 @@ export class OasstError { export class OasstApiClient { oasstApiUrl: string; oasstApiKey: string; + userHeaders: Record = {}; - constructor(oasstApiUrl: string, oasstApiKey: string) { + constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) { this.oasstApiUrl = oasstApiUrl; this.oasstApiKey = oasstApiKey; + if (user) { + this.userHeaders = { + "X-OASST-USER": `${user.auth_method}:${user.id}`, + }; + } } // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. @@ -215,9 +224,10 @@ export class OasstApiClient { method, ...init, headers: { + ...init?.headers, + ...this.userHeaders, "X-API-Key": this.oasstApiKey, "Content-Type": "application/json", - ...init?.headers, }, }); @@ -227,8 +237,7 @@ export class OasstApiClient { if (resp.status >= 300) { const errorText = await resp.text(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let error: any; + let error; try { error = JSON.parse(errorText); } catch (e) { @@ -241,6 +250,9 @@ export class OasstApiClient { } } -const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); +export const createApiClientFromUser = (user: BackendUserCore) => + new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user); -export { oasstApiClient }; +export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub)); + +export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/pages/api/admin/status.ts b/website/src/pages/api/admin/status.ts index 1da03da8..b496b813 100644 --- a/website/src/pages/api/admin/status.ts +++ b/website/src/pages/api/admin/status.ts @@ -1,17 +1,17 @@ -import { getToken } from "next-auth/jwt"; import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; -import { getBackendUserCore } from "src/lib/users"; +import { createApiClientFromUser } from "src/lib/oasst_api_client"; /** * Returns tasks availability, stats, and tree manager stats. */ const handler = withRole("admin", async (req, res) => { + // NOTE: why are we using a dummy user here? const dummyUser = { id: "__dummy_user__", display_name: "Dummy User", auth_method: "local", }; + const oasstApiClient = createApiClientFromUser(dummyUser); const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([ oasstApiClient.fetch_tasks_availability(dummyUser), oasstApiClient.fetch_stats(), diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts index 341ec736..ca44b7b5 100644 --- a/website/src/pages/api/admin/update_user.ts +++ b/website/src/pages/api/admin/update_user.ts @@ -1,22 +1,19 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; /** * Update's the user's data in the database. Accessible only to admins. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { id, auth_method, user_id, notes, role } = req.body; + const oasstApiClient = await createApiClient(token); // If the user is authorized by the web, update their role. if (auth_method === "local") { await prisma.user.update({ - where: { - id, - }, - data: { - role, - }, + where: { id }, + data: { role }, }); } // Tell the backend the user's enabled or not enabled status. diff --git a/website/src/pages/api/admin/user_messages.ts b/website/src/pages/api/admin/user_messages.ts index 236afa2d..fdde2e5f 100644 --- a/website/src/pages/api/admin/user_messages.ts +++ b/website/src/pages/api/admin/user_messages.ts @@ -1,12 +1,13 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import type { Message } from "src/types/Conversation"; /** * Returns the messages recorded by the backend for a user. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { user } = req.query; + const oasstApiClient = await createApiClient(token); const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index d10c91b0..e6f05d61 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { FetchUsersParams } from "src/types/Users"; @@ -17,9 +17,10 @@ const PAGE_SIZE = 20; * - `direction`: Either "forward" or "backward" representing the pagination * direction. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query; + const oasstApiClient = await createApiClient(token); // First, get all the users according to the backend. const { items: all_users, ...rest } = await oasstApiClient.fetch_users({ searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"], diff --git a/website/src/pages/api/available_tasks.ts b/website/src/pages/api/available_tasks.ts index 79dcc8b9..74630dd6 100644 --- a/website/src/pages/api/available_tasks.ts +++ b/website/src/pages/api/available_tasks.ts @@ -1,9 +1,10 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_api_client"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); const userLanguage = getUserLanguage(req); const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage); res.status(200).json(availableTasks); diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index be91e7b4..fc07a3b2 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -1,11 +1,12 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; /** * Returns the set of valid labels that can be applied to messages. */ -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { + const oasstApiClient = await createApiClient(token); const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day; const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number }); res.status(200).json(info); diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts index cb9728fe..fdd01313 100644 --- a/website/src/pages/api/messages/index.ts +++ b/website/src/pages/api/messages/index.ts @@ -1,6 +1,8 @@ import { withoutRole } from "src/lib/auth"; const handler = withoutRole("banned", async (req, res) => { + // TODO: move to oasst_api_client + const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, { method: "GET", headers: { diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index 6f39aad1..3f2e739f 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -2,13 +2,14 @@ import { withoutRole } from "src/lib/auth"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { - //TODO: add params if needed const user = await getBackendUserCore(token.sub); const params = new URLSearchParams({ username: user.id, auth_method: user.auth_method, }); + // TODO: move to oasst_api_client + const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { method: "GET", headers: { diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 360b8faa..cd964312 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; @@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const userLanguage = getUserLanguage(req); const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); let task; try { task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage); diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts index 0084fffa..190b94ba 100644 --- a/website/src/pages/api/reject_task.ts +++ b/website/src/pages/api/reject_task.ts @@ -1,19 +1,21 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { // Parse out the local task ID and the interaction contents. const { id: frontendId, reason } = req.body; - const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }); + const [oasstApiClient, registeredTask] = await Promise.all([ + createApiClient(token), + prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }), + ]); - const task = registeredTask.task as Prisma.JsonObject; - const id = task.id as string; + const taskId = (registeredTask.task as Prisma.JsonObject).id as string; // Update the backend with the rejection - await oasstApiClient.nackTask(id, reason); + await oasstApiClient.nackTask(taskId, reason); // Send the results to the client. res.status(200).json({}); diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts index f72b2808..3c54a89f 100644 --- a/website/src/pages/api/set_label.ts +++ b/website/src/pages/api/set_label.ts @@ -5,6 +5,7 @@ import { withoutRole } from "src/lib/auth"; * */ const handler = withoutRole("banned", async (req, res, token) => { + // TODO: move to oasst_api_client // Parse out the local message_id, and the interaction contents. const { message_id, label_map } = req.body; diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 6f08d640..524e6b7a 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -1,6 +1,6 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; @@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => { // Parse out the local task ID and the interaction contents. const { id: frontendId, content, update_type } = req.body; - // Record that the user has done meaningful work and is no longer new. - await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }); + // do in parallel since they are independent + const [_, registeredTask, oasstApiClient] = await Promise.all([ + // Record that the user has done meaningful work and is no longer new. + prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }), + // Accept the task so that we can complete it, this will probably go away soon. + prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }), + // Create client for upcoming requests + createApiClient(token), + ]); + + const taskId = (registeredTask.task as Prisma.JsonObject).id as string; - // Accept the task so that we can complete it, this will probably go away soon. - const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }); - const task = registeredTask.task as Prisma.JsonObject; - const taskId = task.id as string; await oasstApiClient.ackTask(taskId, registeredTask.id); // Log the interaction locally to create our user_post_id needed by the Task diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts index 5f61ff2f..0fa54e5b 100644 --- a/website/src/pages/api/valid_labels.ts +++ b/website/src/pages/api/valid_labels.ts @@ -1,11 +1,12 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_api_client"; /** * Returns the set of valid labels that can be applied to messages. */ -const handler = withoutRole("banned", async (req, res) => { - const valid_labels = await oasstApiClient.fetch_valid_text(); +const handler = withoutRole("banned", async (req, res, token) => { + const client = await createApiClient(token); + const valid_labels = await client.fetch_valid_text(); res.status(200).json(valid_labels); });