Merge pull request #959 from LAION-AI/oasst-header

Send `x-oasst-user` header to backend
This commit is contained in:
Keith Stevens
2023-01-28 16:23:07 +09:00
committed by GitHub
28 changed files with 134 additions and 154 deletions
@@ -34,7 +34,7 @@ interface MessageTableEntryProps {
export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) {
const router = useRouter();
const [emojis, setEmojis] = useState<MessageEmojis>({ emojis: {}, user_emojis: [] });
const [emojiState, setEmojis] = useState<MessageEmojis>({ emojis: {}, user_emojis: [] });
useEffect(() => {
setEmojis({ emojis: message.emojis, user_emojis: message.user_emojis });
}, [message.emojis, message.user_emojis]);
@@ -93,17 +93,17 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE
style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }}
onClick={(e) => e.stopPropagation()}
>
{Object.entries(emojis.emojis).map(([emoji, count]) => (
{Object.entries(emojiState.emojis).map(([emoji, count]) => (
<MessageEmojiButton
key={emoji}
emoji={{ name: emoji, count }}
checked={emojis.user_emojis.includes(emoji)}
onClick={() => react(emoji, !emojis.user_emojis.includes(emoji))}
checked={emojiState.user_emojis.includes(emoji)}
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
/>
))}
<MessageActions
react={react}
userEmoji={emojis.user_emojis}
userEmoji={emojiState.user_emojis}
onLabel={showLabelPopup}
onReport={showReportPopup}
messageId={message.id}
@@ -211,7 +211,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label
}}
alignItems="center"
>
<Text>
<Text as="div">
{textA}
{descriptionA.length > 0 ? <Explain explanation={descriptionA} /> : null}
</Text>
@@ -229,7 +229,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label
/>
</GridItem>
<GridItem>
<Text textAlign="right">
<Text textAlign="right" as="div">
{textB}
{descriptionB.length > 0 ? <Explain explanation={descriptionB} /> : null}
</Text>
@@ -6,7 +6,6 @@ import { LabelInputGroup } from "src/components/Survey/LabelInputGroup";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { TaskSurveyProps } from "src/components/Tasks/Task";
import { TaskHeader } from "src/components/Tasks/TaskHeader";
import { TaskType } from "src/types/Task";
export const LabelTask = ({
task,
@@ -35,14 +34,7 @@ export const LabelTask = ({
{task.conversation ? (
<Box mt="4" p={[4, 6]} borderRadius="lg" bg={cardColor}>
<MessageTable
messages={[
...(task.conversation?.messages ?? []),
{
text: task.reply,
is_assistant: task.type === TaskType.label_assistant_reply,
message_id: task.message_id,
},
]}
messages={[...(task.conversation?.messages ?? []), task.reply_message]}
highlightLastMessage
/>
</Box>
+2 -2
View File
@@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe
* Wraps any API Route handler and verifies that the user has the appropriate
* role before running the handler. Returns a 403 otherwise.
*/
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => {
return async (req: NextApiRequest, res: NextApiResponse) => {
const token = await getToken({ req });
if (!token || token.role !== role) {
res.status(403).end();
return;
}
return handler(req, res);
return handler(req, res, token);
};
};
+30 -8
View File
@@ -18,10 +18,16 @@ export class OasstError {
export class OasstApiClient {
oasstApiUrl: string;
oasstApiKey: string;
userHeaders: Record<string, string> = {};
constructor(oasstApiUrl: string, oasstApiKey: string) {
constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) {
this.oasstApiUrl = oasstApiUrl;
this.oasstApiKey = oasstApiKey;
if (user) {
this.userHeaders = {
"X-OASST-USER": `${user.auth_method}:${user.id}`,
};
}
}
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
@@ -215,9 +221,10 @@ export class OasstApiClient {
method,
...init,
headers: {
...init?.headers,
...this.userHeaders,
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
...init?.headers,
},
});
@@ -227,8 +234,7 @@ export class OasstApiClient {
if (resp.status >= 300) {
const errorText = await resp.text();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let error: any;
let error;
try {
error = JSON.parse(errorText);
} catch (e) {
@@ -239,8 +245,24 @@ export class OasstApiClient {
return await resp.json();
}
fetch_my_messages(user: BackendUserCore) {
const params = new URLSearchParams({
username: user.id,
auth_method: user.auth_method,
});
return this.get<Message[]>(`/api/v1/messages?${params}`);
}
fetch_recent_messages() {
return this.get<Message[]>(`/api/v1/messages`);
}
fetch_message_children(messageId: string) {
return this.get<Message[]>(`/api/v1/messages/${messageId}/children`);
}
fetch_conversation(messageId: string) {
return this.get(`/api/v1/messages/${messageId}/conversation`);
}
}
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
export { oasstApiClient };
+11
View File
@@ -0,0 +1,11 @@
import { JWT } from "next-auth/jwt";
import { OasstApiClient } from "src/lib/oasst_api_client";
import { getBackendUserCore } from "src/lib/users";
import { BackendUserCore } from "src/types/Users";
export const createApiClientFromUser = (user: BackendUserCore) =>
new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user);
export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub));
export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
+2 -2
View File
@@ -10,7 +10,7 @@ import { getAdminLayout } from "src/components/Layout";
import { Role, RoleSelect } from "src/components/RoleSelect";
import { UserMessagesCell } from "src/components/UserMessagesCell";
import { post } from "src/lib/api";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { userlessApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import useSWRMutation from "swr/mutation";
@@ -113,7 +113,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType<typeof getServerSidePr
* Fetch the user's data on the server side when rendering.
*/
export async function getServerSideProps({ query, locale }) {
const backend_user = await oasstApiClient.fetch_user(query.id);
const backend_user = await userlessApiClient.fetch_user(query.id);
const local_user = await prisma.user.findUnique({
where: { id: backend_user.id },
select: {
+3 -3
View File
@@ -1,17 +1,17 @@
import { getToken } from "next-auth/jwt";
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { getBackendUserCore } from "src/lib/users";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
/**
* Returns tasks availability, stats, and tree manager stats.
*/
const handler = withRole("admin", async (req, res) => {
// NOTE: why are we using a dummy user here?
const dummyUser = {
id: "__dummy_user__",
display_name: "Dummy User",
auth_method: "local",
};
const oasstApiClient = createApiClientFromUser(dummyUser);
const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([
oasstApiClient.fetch_tasks_availability(dummyUser),
oasstApiClient.fetch_stats(),
+5 -8
View File
@@ -1,22 +1,19 @@
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
/**
* Update's the user's data in the database. Accessible only to admins.
*/
const handler = withRole("admin", async (req, res) => {
const handler = withRole("admin", async (req, res, token) => {
const { id, auth_method, user_id, notes, role } = req.body;
const oasstApiClient = await createApiClient(token);
// If the user is authorized by the web, update their role.
if (auth_method === "local") {
await prisma.user.update({
where: {
id,
},
data: {
role,
},
where: { id },
data: { role },
});
}
// Tell the backend the user's enabled or not enabled status.
+3 -2
View File
@@ -1,12 +1,13 @@
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
import type { Message } from "src/types/Conversation";
/**
* Returns the messages recorded by the backend for a user.
*/
const handler = withRole("admin", async (req, res) => {
const handler = withRole("admin", async (req, res, token) => {
const { user } = req.query;
const oasstApiClient = await createApiClient(token);
const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string);
res.status(200).json(messages);
});
+3 -2
View File
@@ -1,5 +1,5 @@
import { withRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import { FetchUsersParams } from "src/types/Users";
@@ -17,9 +17,10 @@ const PAGE_SIZE = 20;
* - `direction`: Either "forward" or "backward" representing the pagination
* direction.
*/
const handler = withRole("admin", async (req, res) => {
const handler = withRole("admin", async (req, res, token) => {
const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query;
const oasstApiClient = await createApiClient(token);
// First, get all the users according to the backend.
const { items: all_users, ...rest } = await oasstApiClient.fetch_users({
searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"],
+2 -1
View File
@@ -1,9 +1,10 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
const user = await getBackendUserCore(token.sub);
const oasstApiClient = createApiClientFromUser(user);
const userLanguage = getUserLanguage(req);
const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage);
res.status(200).json(availableTasks);
+3 -2
View File
@@ -1,11 +1,12 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
/**
* Returns the set of valid labels that can be applied to messages.
*/
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
const oasstApiClient = await createApiClient(token);
const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day;
const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number });
res.status(200).json(info);
@@ -1,18 +1,10 @@
import { withoutRole } from "src/lib/auth";
import { createApiClient } from "src/lib/oasst_client_factory";
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
const client = await createApiClient(token);
const messages = await client.fetch_message_children(id as string);
res.status(200).json(messages);
});
@@ -1,18 +1,10 @@
import { withoutRole } from "src/lib/auth";
import { createApiClient } from "src/lib/oasst_client_factory";
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
const client = await createApiClient(token);
const messages = await client.fetch_conversation(id as string);
res.status(200).json(messages);
});
+2 -1
View File
@@ -1,5 +1,5 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
@@ -15,6 +15,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
const { emoji, op } = req.body;
const user = await getBackendUserCore(token.sub);
const oasstApiClient = createApiClientFromUser(user);
try {
await oasstApiClient.set_user_message_emoji(messageId, user, emoji, op);
} catch (err) {
+3 -16
View File
@@ -1,25 +1,12 @@
import { withoutRole } from "src/lib/auth";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
const user = await getBackendUserCore(token.sub);
const params = new URLSearchParams({
username: user.id,
auth_method: user.auth_method,
});
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/?${params}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const message = await messageRes.json();
// Send recieved messages to the client.
const client = createApiClientFromUser(user);
const message = await client.fetch_message(id as string, user);
res.status(200).json(message);
});
+7 -21
View File
@@ -1,6 +1,8 @@
import { withoutRole } from "src/lib/auth";
import { createApiClient, createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
if (!id) {
@@ -8,32 +10,16 @@ const handler = withoutRole("banned", async (req, res) => {
return;
}
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const message = await messageRes.json();
const user = await getBackendUserCore(token.sub);
const client = createApiClientFromUser(user);
const message = await client.fetch_message(id as string, user);
if (!message.parent_id) {
res.status(404).end();
return;
}
const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
});
const parent = await parentRes.json();
// Send recieved messages to the client.
const parent = await client.fetch_message(message.parent_id, user);
res.status(200).json(parent);
});
+4 -10
View File
@@ -1,15 +1,9 @@
import { withoutRole } from "src/lib/auth";
import { createApiClient } from "src/lib/oasst_client_factory";
const handler = withoutRole("banned", async (req, res) => {
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
const handler = withoutRole("banned", async (req, res, token) => {
const client = await createApiClient(token);
const messages = await client.fetch_recent_messages();
res.status(200).json(messages);
});
+3 -15
View File
@@ -1,23 +1,11 @@
import { withoutRole } from "src/lib/auth";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
//TODO: add params if needed
const user = await getBackendUserCore(token.sub);
const params = new URLSearchParams({
username: user.id,
auth_method: user.auth_method,
});
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
const client = createApiClientFromUser(user);
const messages = await client.fetch_my_messages(user);
res.status(200).json(messages);
});
@@ -1,5 +1,5 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
@@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
const userLanguage = getUserLanguage(req);
const user = await getBackendUserCore(token.sub);
const oasstApiClient = createApiClientFromUser(user);
let task;
try {
task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage);
+8 -6
View File
@@ -1,19 +1,21 @@
import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
const handler = withoutRole("banned", async (req, res) => {
const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, reason } = req.body;
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const [oasstApiClient, registeredTask] = await Promise.all([
createApiClient(token),
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
]);
const task = registeredTask.task as Prisma.JsonObject;
const id = task.id as string;
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
// Update the backend with the rejection
await oasstApiClient.nackTask(id, reason);
await oasstApiClient.nackTask(taskId, reason);
// Send the results to the client.
res.status(200).json({});
+2 -1
View File
@@ -1,5 +1,5 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";
/**
@@ -11,6 +11,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
const { message_id, text } = req.body;
const user = await getBackendUserCore(token.sub);
const oasstApiClient = createApiClientFromUser(user);
try {
await oasstApiClient.send_report(message_id, user, text);
} catch (err) {
+1
View File
@@ -5,6 +5,7 @@ import { withoutRole } from "src/lib/auth";
*
*/
const handler = withoutRole("banned", async (req, res, token) => {
// TODO: move to oasst_api_client
// Parse out the local message_id, and the interaction contents.
const { message_id, label_map } = req.body;
+12 -7
View File
@@ -1,6 +1,6 @@
import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
@@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, content, update_type } = req.body;
// Record that the user has done meaningful work and is no longer new.
await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
// do in parallel since they are independent
const [_, registeredTask, oasstApiClient] = await Promise.all([
// Record that the user has done meaningful work and is no longer new.
prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }),
// Accept the task so that we can complete it, this will probably go away soon.
prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
// Create client for upcoming requests
createApiClient(token),
]);
const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
// Accept the task so that we can complete it, this will probably go away soon.
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const task = registeredTask.task as Prisma.JsonObject;
const taskId = task.id as string;
await oasstApiClient.ackTask(taskId, registeredTask.id);
// Log the interaction locally to create our user_post_id needed by the Task
+4 -3
View File
@@ -1,11 +1,12 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { createApiClient } from "src/lib/oasst_client_factory";
/**
* Returns the set of valid labels that can be applied to messages.
*/
const handler = withoutRole("banned", async (req, res) => {
const valid_labels = await oasstApiClient.fetch_valid_text();
const handler = withoutRole("banned", async (req, res, token) => {
const client = await createApiClient(token);
const valid_labels = await client.fetch_valid_text();
res.status(200).json(valid_labels);
});
+3 -2
View File
@@ -11,11 +11,12 @@ export interface MessageEmojis {
}
export interface Message extends MessageEmojis {
id: string;
text: string;
is_assistant: boolean;
id: string;
created_date: string; // iso date string
lang: string;
created_date: string; // iso date string
parent_id: string;
frontend_message_id?: string;
}
+3 -1
View File
@@ -1,4 +1,4 @@
import { Conversation } from "./Conversation";
import { Conversation, Message } from "./Conversation";
import { BaseTask, TaskType } from "./Task";
export interface CreateInitialPromptTask extends BaseTask {
@@ -37,6 +37,7 @@ export interface LabelAssistantReplyTask extends BaseTask {
type: TaskType.label_assistant_reply;
message_id: string;
conversation: Conversation;
reply_message: Message;
reply: string;
valid_labels: string[];
mode: "simple" | "full";
@@ -47,6 +48,7 @@ export interface LabelPrompterReplyTask extends BaseTask {
type: TaskType.label_prompter_reply;
message_id: string;
conversation: Conversation;
reply_message: Message;
reply: string;
valid_labels: string[];
mode: "simple" | "full";