diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 3cde48f6..7202903a 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -34,7 +34,7 @@ interface MessageTableEntryProps { export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) { const router = useRouter(); - const [emojis, setEmojis] = useState({ emojis: {}, user_emojis: [] }); + const [emojiState, setEmojis] = useState({ emojis: {}, user_emojis: [] }); useEffect(() => { setEmojis({ emojis: message.emojis, user_emojis: message.user_emojis }); }, [message.emojis, message.user_emojis]); @@ -93,17 +93,17 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }} onClick={(e) => e.stopPropagation()} > - {Object.entries(emojis.emojis).map(([emoji, count]) => ( + {Object.entries(emojiState.emojis).map(([emoji, count]) => ( react(emoji, !emojis.user_emojis.includes(emoji))} + checked={emojiState.user_emojis.includes(emoji)} + onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))} /> ))} - + {textA} {descriptionA.length > 0 ? : null} @@ -229,7 +229,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label /> - + {textB} {descriptionB.length > 0 ? : null} diff --git a/website/src/components/Tasks/LabelTask/LabelTask.tsx b/website/src/components/Tasks/LabelTask/LabelTask.tsx index 2ce5febc..10ea76fb 100644 --- a/website/src/components/Tasks/LabelTask/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask/LabelTask.tsx @@ -6,7 +6,6 @@ import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; import { TaskSurveyProps } from "src/components/Tasks/Task"; import { TaskHeader } from "src/components/Tasks/TaskHeader"; -import { TaskType } from "src/types/Task"; export const LabelTask = ({ task, @@ -35,14 +34,7 @@ export const LabelTask = ({ {task.conversation ? ( diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts index 42c6cf79..ee004ba9 100644 --- a/website/src/lib/auth.ts +++ b/website/src/lib/auth.ts @@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe * Wraps any API Route handler and verifies that the user has the appropriate * role before running the handler. Returns a 403 otherwise. */ -const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => { +const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => { return async (req: NextApiRequest, res: NextApiResponse) => { const token = await getToken({ req }); if (!token || token.role !== role) { res.status(403).end(); return; } - return handler(req, res); + return handler(req, res, token); }; }; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 141e325c..b9a9489e 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -18,10 +18,16 @@ export class OasstError { export class OasstApiClient { oasstApiUrl: string; oasstApiKey: string; + userHeaders: Record = {}; - constructor(oasstApiUrl: string, oasstApiKey: string) { + constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) { this.oasstApiUrl = oasstApiUrl; this.oasstApiKey = oasstApiKey; + if (user) { + this.userHeaders = { + "X-OASST-USER": `${user.auth_method}:${user.id}`, + }; + } } // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. @@ -215,9 +221,10 @@ export class OasstApiClient { method, ...init, headers: { + ...init?.headers, + ...this.userHeaders, "X-API-Key": this.oasstApiKey, "Content-Type": "application/json", - ...init?.headers, }, }); @@ -227,8 +234,7 @@ export class OasstApiClient { if (resp.status >= 300) { const errorText = await resp.text(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let error: any; + let error; try { error = JSON.parse(errorText); } catch (e) { @@ -239,8 +245,24 @@ export class OasstApiClient { return await resp.json(); } + + fetch_my_messages(user: BackendUserCore) { + const params = new URLSearchParams({ + username: user.id, + auth_method: user.auth_method, + }); + return this.get(`/api/v1/messages?${params}`); + } + + fetch_recent_messages() { + return this.get(`/api/v1/messages`); + } + + fetch_message_children(messageId: string) { + return this.get(`/api/v1/messages/${messageId}/children`); + } + + fetch_conversation(messageId: string) { + return this.get(`/api/v1/messages/${messageId}/conversation`); + } } - -const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); - -export { oasstApiClient }; diff --git a/website/src/lib/oasst_client_factory.ts b/website/src/lib/oasst_client_factory.ts new file mode 100644 index 00000000..9f9bd657 --- /dev/null +++ b/website/src/lib/oasst_client_factory.ts @@ -0,0 +1,11 @@ +import { JWT } from "next-auth/jwt"; +import { OasstApiClient } from "src/lib/oasst_api_client"; +import { getBackendUserCore } from "src/lib/users"; +import { BackendUserCore } from "src/types/Users"; + +export const createApiClientFromUser = (user: BackendUserCore) => + new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user); + +export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub)); + +export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index b53bb7c0..a68bca16 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -10,7 +10,7 @@ import { getAdminLayout } from "src/components/Layout"; import { Role, RoleSelect } from "src/components/RoleSelect"; import { UserMessagesCell } from "src/components/UserMessagesCell"; import { post } from "src/lib/api"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { userlessApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import useSWRMutation from "swr/mutation"; @@ -113,7 +113,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType { + // NOTE: why are we using a dummy user here? const dummyUser = { id: "__dummy_user__", display_name: "Dummy User", auth_method: "local", }; + const oasstApiClient = createApiClientFromUser(dummyUser); const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([ oasstApiClient.fetch_tasks_availability(dummyUser), oasstApiClient.fetch_stats(), diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts index 341ec736..c71159ad 100644 --- a/website/src/pages/api/admin/update_user.ts +++ b/website/src/pages/api/admin/update_user.ts @@ -1,22 +1,19 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; /** * Update's the user's data in the database. Accessible only to admins. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { id, auth_method, user_id, notes, role } = req.body; + const oasstApiClient = await createApiClient(token); // If the user is authorized by the web, update their role. if (auth_method === "local") { await prisma.user.update({ - where: { - id, - }, - data: { - role, - }, + where: { id }, + data: { role }, }); } // Tell the backend the user's enabled or not enabled status. diff --git a/website/src/pages/api/admin/user_messages.ts b/website/src/pages/api/admin/user_messages.ts index 236afa2d..0223e8e3 100644 --- a/website/src/pages/api/admin/user_messages.ts +++ b/website/src/pages/api/admin/user_messages.ts @@ -1,12 +1,13 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import type { Message } from "src/types/Conversation"; /** * Returns the messages recorded by the backend for a user. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { user } = req.query; + const oasstApiClient = await createApiClient(token); const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index d10c91b0..eae0f072 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,5 +1,5 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import { FetchUsersParams } from "src/types/Users"; @@ -17,9 +17,10 @@ const PAGE_SIZE = 20; * - `direction`: Either "forward" or "backward" representing the pagination * direction. */ -const handler = withRole("admin", async (req, res) => { +const handler = withRole("admin", async (req, res, token) => { const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query; + const oasstApiClient = await createApiClient(token); // First, get all the users according to the backend. const { items: all_users, ...rest } = await oasstApiClient.fetch_users({ searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"], diff --git a/website/src/pages/api/available_tasks.ts b/website/src/pages/api/available_tasks.ts index 79dcc8b9..218e3864 100644 --- a/website/src/pages/api/available_tasks.ts +++ b/website/src/pages/api/available_tasks.ts @@ -1,9 +1,10 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); const userLanguage = getUserLanguage(req); const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage); res.status(200).json(availableTasks); diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index be91e7b4..fad1d8a6 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -1,11 +1,12 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; /** * Returns the set of valid labels that can be applied to messages. */ -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { + const oasstApiClient = await createApiClient(token); const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day; const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number }); res.status(200).json(info); diff --git a/website/src/pages/api/messages/[id]/children.ts b/website/src/pages/api/messages/[id]/children.ts index 4a184c11..0185b615 100644 --- a/website/src/pages/api/messages/[id]/children.ts +++ b/website/src/pages/api/messages/[id]/children.ts @@ -1,18 +1,10 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient } from "src/lib/oasst_client_factory"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. + const client = await createApiClient(token); + const messages = await client.fetch_message_children(id as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/messages/[id]/conversation.ts b/website/src/pages/api/messages/[id]/conversation.ts index 0c401883..fc284e40 100644 --- a/website/src/pages/api/messages/[id]/conversation.ts +++ b/website/src/pages/api/messages/[id]/conversation.ts @@ -1,18 +1,10 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient } from "src/lib/oasst_client_factory"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. + const client = await createApiClient(token); + const messages = await client.fetch_conversation(id as string); res.status(200).json(messages); }); diff --git a/website/src/pages/api/messages/[id]/emoji.ts b/website/src/pages/api/messages/[id]/emoji.ts index 59d93f69..8d6257c1 100644 --- a/website/src/pages/api/messages/[id]/emoji.ts +++ b/website/src/pages/api/messages/[id]/emoji.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { @@ -15,6 +15,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const { emoji, op } = req.body; const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); try { await oasstApiClient.set_user_message_emoji(messageId, user, emoji, op); } catch (err) { diff --git a/website/src/pages/api/messages/[id]/index.ts b/website/src/pages/api/messages/[id]/index.ts index 28947feb..b3361a2d 100644 --- a/website/src/pages/api/messages/[id]/index.ts +++ b/website/src/pages/api/messages/[id]/index.ts @@ -1,25 +1,12 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; - const user = await getBackendUserCore(token.sub); - const params = new URLSearchParams({ - username: user.id, - auth_method: user.auth_method, - }); - - const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/?${params}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - const message = await messageRes.json(); - - // Send recieved messages to the client. + const client = createApiClientFromUser(user); + const message = await client.fetch_message(id as string, user); res.status(200).json(message); }); diff --git a/website/src/pages/api/messages/[id]/parent.ts b/website/src/pages/api/messages/[id]/parent.ts index ba332be9..9b8c64d4 100644 --- a/website/src/pages/api/messages/[id]/parent.ts +++ b/website/src/pages/api/messages/[id]/parent.ts @@ -1,6 +1,8 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient, createApiClientFromUser } from "src/lib/oasst_client_factory"; +import { getBackendUserCore } from "src/lib/users"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { const { id } = req.query; if (!id) { @@ -8,32 +10,16 @@ const handler = withoutRole("banned", async (req, res) => { return; } - const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - - const message = await messageRes.json(); + const user = await getBackendUserCore(token.sub); + const client = createApiClientFromUser(user); + const message = await client.fetch_message(id as string, user); if (!message.parent_id) { res.status(404).end(); return; } - const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - "Content-Type": "application/json", - }, - }); - - const parent = await parentRes.json(); - - // Send recieved messages to the client. + const parent = await client.fetch_message(message.parent_id, user); res.status(200).json(parent); }); diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts index cb9728fe..fbcaee3c 100644 --- a/website/src/pages/api/messages/index.ts +++ b/website/src/pages/api/messages/index.ts @@ -1,15 +1,9 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClient } from "src/lib/oasst_client_factory"; -const handler = withoutRole("banned", async (req, res) => { - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. +const handler = withoutRole("banned", async (req, res, token) => { + const client = await createApiClient(token); + const messages = await client.fetch_recent_messages(); res.status(200).json(messages); }); diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index 6f39aad1..bffe2fb6 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -1,23 +1,11 @@ import { withoutRole } from "src/lib/auth"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { - //TODO: add params if needed const user = await getBackendUserCore(token.sub); - const params = new URLSearchParams({ - username: user.id, - auth_method: user.auth_method, - }); - - const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { - method: "GET", - headers: { - "X-API-Key": process.env.FASTAPI_KEY, - }, - }); - const messages = await messagesRes.json(); - - // Send recieved messages to the client. + const client = createApiClientFromUser(user); + const messages = await client.fetch_my_messages(user); res.status(200).json(messages); }); diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 360b8faa..34993f4f 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; @@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const userLanguage = getUserLanguage(req); const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); let task; try { task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage); diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts index 0084fffa..2e3f4fa1 100644 --- a/website/src/pages/api/reject_task.ts +++ b/website/src/pages/api/reject_task.ts @@ -1,19 +1,21 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; -const handler = withoutRole("banned", async (req, res) => { +const handler = withoutRole("banned", async (req, res, token) => { // Parse out the local task ID and the interaction contents. const { id: frontendId, reason } = req.body; - const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }); + const [oasstApiClient, registeredTask] = await Promise.all([ + createApiClient(token), + prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }), + ]); - const task = registeredTask.task as Prisma.JsonObject; - const id = task.id as string; + const taskId = (registeredTask.task as Prisma.JsonObject).id as string; // Update the backend with the rejection - await oasstApiClient.nackTask(id, reason); + await oasstApiClient.nackTask(taskId, reason); // Send the results to the client. res.status(200).json({}); diff --git a/website/src/pages/api/report.ts b/website/src/pages/api/report.ts index 36252cad..9b904df4 100644 --- a/website/src/pages/api/report.ts +++ b/website/src/pages/api/report.ts @@ -1,5 +1,5 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClientFromUser } from "src/lib/oasst_client_factory"; import { getBackendUserCore } from "src/lib/users"; /** @@ -11,6 +11,7 @@ const handler = withoutRole("banned", async (req, res, token) => { const { message_id, text } = req.body; const user = await getBackendUserCore(token.sub); + const oasstApiClient = createApiClientFromUser(user); try { await oasstApiClient.send_report(message_id, user, text); } catch (err) { diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts index f72b2808..3c54a89f 100644 --- a/website/src/pages/api/set_label.ts +++ b/website/src/pages/api/set_label.ts @@ -5,6 +5,7 @@ import { withoutRole } from "src/lib/auth"; * */ const handler = withoutRole("banned", async (req, res, token) => { + // TODO: move to oasst_api_client // Parse out the local message_id, and the interaction contents. const { message_id, label_map } = req.body; diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 6f08d640..1b4f2eda 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -1,6 +1,6 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; import prisma from "src/lib/prismadb"; import { getBackendUserCore, getUserLanguage } from "src/lib/users"; @@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => { // Parse out the local task ID and the interaction contents. const { id: frontendId, content, update_type } = req.body; - // Record that the user has done meaningful work and is no longer new. - await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }); + // do in parallel since they are independent + const [_, registeredTask, oasstApiClient] = await Promise.all([ + // Record that the user has done meaningful work and is no longer new. + prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }), + // Accept the task so that we can complete it, this will probably go away soon. + prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }), + // Create client for upcoming requests + createApiClient(token), + ]); + + const taskId = (registeredTask.task as Prisma.JsonObject).id as string; - // Accept the task so that we can complete it, this will probably go away soon. - const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }); - const task = registeredTask.task as Prisma.JsonObject; - const taskId = task.id as string; await oasstApiClient.ackTask(taskId, registeredTask.id); // Log the interaction locally to create our user_post_id needed by the Task diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts index 5f61ff2f..dca92d90 100644 --- a/website/src/pages/api/valid_labels.ts +++ b/website/src/pages/api/valid_labels.ts @@ -1,11 +1,12 @@ import { withoutRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { createApiClient } from "src/lib/oasst_client_factory"; /** * Returns the set of valid labels that can be applied to messages. */ -const handler = withoutRole("banned", async (req, res) => { - const valid_labels = await oasstApiClient.fetch_valid_text(); +const handler = withoutRole("banned", async (req, res, token) => { + const client = await createApiClient(token); + const valid_labels = await client.fetch_valid_text(); res.status(200).json(valid_labels); }); diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts index 9cc3a140..8b258a25 100644 --- a/website/src/types/Conversation.ts +++ b/website/src/types/Conversation.ts @@ -11,11 +11,12 @@ export interface MessageEmojis { } export interface Message extends MessageEmojis { + id: string; text: string; is_assistant: boolean; - id: string; - created_date: string; // iso date string lang: string; + created_date: string; // iso date string + parent_id: string; frontend_message_id?: string; } diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts index a791916e..bbbe3a67 100644 --- a/website/src/types/Tasks.ts +++ b/website/src/types/Tasks.ts @@ -1,4 +1,4 @@ -import { Conversation } from "./Conversation"; +import { Conversation, Message } from "./Conversation"; import { BaseTask, TaskType } from "./Task"; export interface CreateInitialPromptTask extends BaseTask { @@ -37,6 +37,7 @@ export interface LabelAssistantReplyTask extends BaseTask { type: TaskType.label_assistant_reply; message_id: string; conversation: Conversation; + reply_message: Message; reply: string; valid_labels: string[]; mode: "simple" | "full"; @@ -47,6 +48,7 @@ export interface LabelPrompterReplyTask extends BaseTask { type: TaskType.label_prompter_reply; message_id: string; conversation: Conversation; + reply_message: Message; reply: string; valid_labels: string[]; mode: "simple" | "full";