mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Send x-oasst-user header
This commit is contained in:
@@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe
|
||||
* Wraps any API Route handler and verifies that the user has the appropriate
|
||||
* role before running the handler. Returns a 403 otherwise.
|
||||
*/
|
||||
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
|
||||
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role !== role) {
|
||||
res.status(403).end();
|
||||
return;
|
||||
}
|
||||
return handler(req, res);
|
||||
return handler(req, res, token);
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { JWT } from "next-auth/jwt";
|
||||
import type { EmojiOp, Message } from "src/types/Conversation";
|
||||
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import type { AvailableTasks } from "src/types/Task";
|
||||
import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users";
|
||||
|
||||
import { getBackendUserCore } from "./users";
|
||||
|
||||
export class OasstError {
|
||||
message: string;
|
||||
errorCode: number;
|
||||
@@ -18,10 +21,16 @@ export class OasstError {
|
||||
export class OasstApiClient {
|
||||
oasstApiUrl: string;
|
||||
oasstApiKey: string;
|
||||
userHeaders: Record<string, string> = {};
|
||||
|
||||
constructor(oasstApiUrl: string, oasstApiKey: string) {
|
||||
constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) {
|
||||
this.oasstApiUrl = oasstApiUrl;
|
||||
this.oasstApiKey = oasstApiKey;
|
||||
if (user) {
|
||||
this.userHeaders = {
|
||||
"X-OASST-USER": `${user.auth_method}:${user.id}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
// TODO return a strongly typed Task?
|
||||
// This method is used to store a task in RegisteredTask.task.
|
||||
@@ -215,9 +224,10 @@ export class OasstApiClient {
|
||||
method,
|
||||
...init,
|
||||
headers: {
|
||||
...init?.headers,
|
||||
...this.userHeaders,
|
||||
"X-API-Key": this.oasstApiKey,
|
||||
"Content-Type": "application/json",
|
||||
...init?.headers,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -227,8 +237,7 @@ export class OasstApiClient {
|
||||
|
||||
if (resp.status >= 300) {
|
||||
const errorText = await resp.text();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let error: any;
|
||||
let error;
|
||||
try {
|
||||
error = JSON.parse(errorText);
|
||||
} catch (e) {
|
||||
@@ -241,6 +250,9 @@ export class OasstApiClient {
|
||||
}
|
||||
}
|
||||
|
||||
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
|
||||
export const createApiClientFromUser = (user: BackendUserCore) =>
|
||||
new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user);
|
||||
|
||||
export { oasstApiClient };
|
||||
export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub));
|
||||
|
||||
export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_api_client";
|
||||
|
||||
/**
|
||||
* Returns tasks availability, stats, and tree manager stats.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
// NOTE: why are we using a dummy user here?
|
||||
const dummyUser = {
|
||||
id: "__dummy_user__",
|
||||
display_name: "Dummy User",
|
||||
auth_method: "local",
|
||||
};
|
||||
const oasstApiClient = createApiClientFromUser(dummyUser);
|
||||
const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([
|
||||
oasstApiClient.fetch_tasks_availability(dummyUser),
|
||||
oasstApiClient.fetch_stats(),
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Update's the user's data in the database. Accessible only to admins.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const handler = withRole("admin", async (req, res, token) => {
|
||||
const { id, auth_method, user_id, notes, role } = req.body;
|
||||
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
// If the user is authorized by the web, update their role.
|
||||
if (auth_method === "local") {
|
||||
await prisma.user.update({
|
||||
where: {
|
||||
id,
|
||||
},
|
||||
data: {
|
||||
role,
|
||||
},
|
||||
where: { id },
|
||||
data: { role },
|
||||
});
|
||||
}
|
||||
// Tell the backend the user's enabled or not enabled status.
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
import type { Message } from "src/types/Conversation";
|
||||
|
||||
/**
|
||||
* Returns the messages recorded by the backend for a user.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const handler = withRole("admin", async (req, res, token) => {
|
||||
const { user } = req.query;
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string);
|
||||
res.status(200).json(messages);
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { FetchUsersParams } from "src/types/Users";
|
||||
|
||||
@@ -17,9 +17,10 @@ const PAGE_SIZE = 20;
|
||||
* - `direction`: Either "forward" or "backward" representing the pagination
|
||||
* direction.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const handler = withRole("admin", async (req, res, token) => {
|
||||
const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query;
|
||||
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
// First, get all the users according to the backend.
|
||||
const { items: all_users, ...rest } = await oasstApiClient.fetch_users({
|
||||
searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"],
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_api_client";
|
||||
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const oasstApiClient = createApiClientFromUser(user);
|
||||
const userLanguage = getUserLanguage(req);
|
||||
const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage);
|
||||
res.status(200).json(availableTasks);
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
/**
|
||||
* Returns the set of valid labels that can be applied to messages.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const oasstApiClient = await createApiClient(token);
|
||||
const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day;
|
||||
const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number });
|
||||
res.status(200).json(info);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
// TODO: move to oasst_api_client
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
|
||||
@@ -2,13 +2,14 @@ import { withoutRole } from "src/lib/auth";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
//TODO: add params if needed
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const params = new URLSearchParams({
|
||||
username: user.id,
|
||||
auth_method: user.auth_method,
|
||||
});
|
||||
|
||||
// TODO: move to oasst_api_client
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClientFromUser } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
|
||||
|
||||
@@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const userLanguage = getUserLanguage(req);
|
||||
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const oasstApiClient = createApiClientFromUser(user);
|
||||
let task;
|
||||
try {
|
||||
task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage);
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, reason } = req.body;
|
||||
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const [oasstApiClient, registeredTask] = await Promise.all([
|
||||
createApiClient(token),
|
||||
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
|
||||
]);
|
||||
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const id = task.id as string;
|
||||
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
|
||||
|
||||
// Update the backend with the rejection
|
||||
await oasstApiClient.nackTask(id, reason);
|
||||
await oasstApiClient.nackTask(taskId, reason);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json({});
|
||||
|
||||
@@ -5,6 +5,7 @@ import { withoutRole } from "src/lib/auth";
|
||||
*
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// TODO: move to oasst_api_client
|
||||
// Parse out the local message_id, and the interaction contents.
|
||||
const { message_id, label_map } = req.body;
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
|
||||
|
||||
@@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, content, update_type } = req.body;
|
||||
|
||||
// Record that the user has done meaningful work and is no longer new.
|
||||
await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
|
||||
// do in parallel since they are independent
|
||||
const [_, registeredTask, oasstApiClient] = await Promise.all([
|
||||
// Record that the user has done meaningful work and is no longer new.
|
||||
prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }),
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
|
||||
// Create client for upcoming requests
|
||||
createApiClient(token),
|
||||
]);
|
||||
|
||||
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
|
||||
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const taskId = task.id as string;
|
||||
await oasstApiClient.ackTask(taskId, registeredTask.id);
|
||||
|
||||
// Log the interaction locally to create our user_post_id needed by the Task
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { createApiClient } from "src/lib/oasst_api_client";
|
||||
|
||||
/**
|
||||
* Returns the set of valid labels that can be applied to messages.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const valid_labels = await oasstApiClient.fetch_valid_text();
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const client = await createApiClient(token);
|
||||
const valid_labels = await client.fetch_valid_text();
|
||||
res.status(200).json(valid_labels);
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user