Send x-oasst-user header

This commit is contained in:
AbdBarho
2023-01-27 18:18:17 +01:00
parent 3a32a10b23
commit cdb9f2da4e
15 changed files with 70 additions and 44 deletions
+2 -2
View File
@@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe
* Wraps any API Route handler and verifies that the user has the appropriate
* role before running the handler. Returns a 403 otherwise.
*/
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => {
return async (req: NextApiRequest, res: NextApiResponse) => {
const token = await getToken({ req });
if (!token || token.role !== role) {
res.status(403).end();
return;
}
return handler(req, res);
return handler(req, res, token);
};
};
+18 -6
View File
@@ -1,8 +1,11 @@
import { JWT } from "next-auth/jwt";
import type { EmojiOp, Message } from "src/types/Conversation";
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
import type { AvailableTasks } from "src/types/Task";
import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users";
import { getBackendUserCore } from "./users";
export class OasstError {
message: string;
errorCode: number;
@@ -18,10 +21,16 @@ export class OasstError {
export class OasstApiClient {
oasstApiUrl: string;
oasstApiKey: string;
userHeaders: Record<string, string> = {};
constructor(oasstApiUrl: string, oasstApiKey: string) {
constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) {
this.oasstApiUrl = oasstApiUrl;
this.oasstApiKey = oasstApiKey;
if (user) {
this.userHeaders = {
"X-OASST-USER": `${user.auth_method}:${user.id}`,
};
}
}
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
@@ -215,9 +224,10 @@ export class OasstApiClient {
method,
...init,
headers: {
...init?.headers,
...this.userHeaders,
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
...init?.headers,
},
});
@@ -227,8 +237,7 @@ export class OasstApiClient {
if (resp.status >= 300) {
const errorText = await resp.text();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let error: any;
let error;
try {
error = JSON.parse(errorText);
} catch (e) {
@@ -241,6 +250,9 @@ export class OasstApiClient {
}
}
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
export const createApiClientFromUser = (user: BackendUserCore) =>
new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user);
export { oasstApiClient };
export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub));
export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
+3 -3
View File
@@ -1,17 +1,17 @@
import { getToken } from "next-auth/jwt";
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { getBackendUserCore } from "src/lib/users";
import { createApiClientFromUser } from "src/lib/oasst_api_client";
/**
* Returns tasks availability, stats, and tree manager stats.
*/
const handler = withRole("admin", async (req, res) => {
// NOTE: why are we using a dummy user here?
const dummyUser = {
id: "__dummy_user__",
display_name: "Dummy User",
auth_method: "local",
};
const oasstApiClient = createApiClientFromUser(dummyUser);
const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([
oasstApiClient.fetch_tasks_availability(dummyUser),
oasstApiClient.fetch_stats(),
+5 -8
View File
@@ -1,22 +1,19 @@
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
/**
* Update's the user's data in the database. Accessible only to admins.
*/
const handler = withRole("admin", async (req, res) => {
const handler = withRole("admin", async (req, res, token) => {
const { id, auth_method, user_id, notes, role } = req.body;
const oasstApiClient = await createApiClient(token);
// If the user is authorized by the web, update their role.
if (auth_method === "local") {
await prisma.user.update({
where: {
id,
},
data: {
role,
},
where: { id },
data: { role },
});
}
// Tell the backend the user's enabled or not enabled status.
+3 -2
View File
@@ -1,12 +1,13 @@
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
import type { Message } from "src/types/Conversation";
/**
* Returns the messages recorded by the backend for a user.
*/
const handler = withRole("admin", async (req, res) => {
const handler = withRole("admin", async (req, res, token) => {
const { user } = req.query;
const oasstApiClient = await createApiClient(token);
const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string);
res.status(200).json(messages);
});
+3 -2
View File
@@ -1,5 +1,5 @@
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
import { FetchUsersParams } from "src/types/Users";
@@ -17,9 +17,10 @@ const PAGE_SIZE = 20;
* - `direction`: Either "forward" or "backward" representing the pagination
* direction.
*/
const handler = withRole("admin", async (req, res) => {
const handler = withRole("admin", async (req, res, token) => {
const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query;
const oasstApiClient = await createApiClient(token);
// First, get all the users according to the backend.
const { items: all_users, ...rest } = await oasstApiClient.fetch_users({
searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"],
+2 -1
View File
@@ -1,9 +1,10 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClientFromUser } from "src/lib/oasst_api_client";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
const user = await getBackendUserCore(token.sub);
const oasstApiClient = createApiClientFromUser(user);
const userLanguage = getUserLanguage(req);
const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage);
res.status(200).json(availableTasks);
+3 -2
View File
@@ -1,11 +1,12 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
/**
* Returns the set of valid labels that can be applied to messages.
*/
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
const oasstApiClient = await createApiClient(token);
const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day;
const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number });
res.status(200).json(info);
+2
View File
@@ -1,6 +1,8 @@
import { withoutRole } from "src/lib/auth";
const handler = withoutRole("banned", async (req, res) => {
// TODO: move to oasst_api_client
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
method: "GET",
headers: {
+2 -1
View File
@@ -2,13 +2,14 @@ import { withoutRole } from "src/lib/auth";
import { getBackendUserCore } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
//TODO: add params if needed
const user = await getBackendUserCore(token.sub);
const params = new URLSearchParams({
username: user.id,
auth_method: user.auth_method,
});
// TODO: move to oasst_api_client
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
method: "GET",
headers: {
@@ -1,5 +1,5 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClientFromUser } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
@@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
const userLanguage = getUserLanguage(req);
const user = await getBackendUserCore(token.sub);
const oasstApiClient = createApiClientFromUser(user);
let task;
try {
task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage);
+8 -6
View File
@@ -1,19 +1,21 @@
import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, reason } = req.body;
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const [oasstApiClient, registeredTask] = await Promise.all([
createApiClient(token),
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
]);
const task = registeredTask.task as Prisma.JsonObject;
const id = task.id as string;
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
// Update the backend with the rejection
await oasstApiClient.nackTask(id, reason);
await oasstApiClient.nackTask(taskId, reason);
// Send the results to the client.
res.status(200).json({});
+1
View File
@@ -5,6 +5,7 @@ import { withoutRole } from "src/lib/auth";
*
*/
const handler = withoutRole("banned", async (req, res, token) => {
// TODO: move to oasst_api_client
// Parse out the local message_id, and the interaction contents.
const { message_id, label_map } = req.body;
+12 -7
View File
@@ -1,6 +1,6 @@
import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
@@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, content, update_type } = req.body;
// Record that the user has done meaningful work and is no longer new.
await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
// do in parallel since they are independent
const [_, registeredTask, oasstApiClient] = await Promise.all([
// Record that the user has done meaningful work and is no longer new.
prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }),
// Accept the task so that we can complete it, this will probably go away soon.
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
// Create client for upcoming requests
createApiClient(token),
]);
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
// Accept the task so that we can complete it, this will probably go away soon.
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const task = registeredTask.task as Prisma.JsonObject;
const taskId = task.id as string;
await oasstApiClient.ackTask(taskId, registeredTask.id);
// Log the interaction locally to create our user_post_id needed by the Task
+4 -3
View File
@@ -1,11 +1,12 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_api_client";
/**
* Returns the set of valid labels that can be applied to messages.
*/
const handler = withoutRole("banned", async (req, res) => {
const valid_labels = await oasstApiClient.fetch_valid_text();
const handler = withoutRole("banned", async (req, res, token) => {
const client = await createApiClient(token);
const valid_labels = await client.fetch_valid_text();
res.status(200).json(valid_labels);
});