(null);
let messages = [];
- if (task.conversation) {
+ if (task.type !== TaskType.rank_initial_prompts) {
messages = task.conversation.messages;
- messages = messages.map((message, index) => ({ ...message, id: index }));
}
useEffect(() => {
if (ranking === null) {
- const defaultRanking = (task.replies ?? task.prompts).map((_, idx) => idx);
- onReplyChanged({ ranking: defaultRanking });
+ if (task.type === TaskType.rank_initial_prompts) {
+ onReplyChanged({ ranking: task.prompts.map((_, idx) => idx) });
+ } else {
+ onReplyChanged({ ranking: task.replies.map((_, idx) => idx) });
+ }
onValidityChanged("DEFAULT");
} else {
onReplyChanged({ ranking });
@@ -34,7 +41,7 @@ export const EvaluateTask = ({
}
}, [task, ranking, onReplyChanged, onValidityChanged]);
- const sortables = task.replies ? "replies" : "prompts";
+ const sortables = task.type === TaskType.rank_initial_prompts ? "prompts" : "replies";
return (
diff --git a/website/src/components/Tasks/LabelTask/LabelTask.tsx b/website/src/components/Tasks/LabelTask/LabelTask.tsx
index 2a08eb66..33152ba1 100644
--- a/website/src/components/Tasks/LabelTask/LabelTask.tsx
+++ b/website/src/components/Tasks/LabelTask/LabelTask.tsx
@@ -1,12 +1,18 @@
-import { Box, Button, Flex, HStack, Text, useColorModeValue } from "@chakra-ui/react";
+import { Box, useBoolean, useColorModeValue } from "@chakra-ui/react";
+import { useTranslation } from "next-i18next";
import { useEffect, useState } from "react";
import { MessageView } from "src/components/Messages";
+import { LabelInputGroup } from "src/components/Messages/LabelInputGroup";
import { MessageTable } from "src/components/Messages/MessageTable";
-import { LabelInputGroup } from "src/components/Survey/LabelInputGroup";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { TaskSurveyProps } from "src/components/Tasks/Task";
import { TaskHeader } from "src/components/Tasks/TaskHeader";
import { TaskType } from "src/types/Task";
+import { LabelTaskType } from "src/types/Tasks";
+
+const isRequired = (labelName: string, requiredLabels?: string[]) => {
+ return requiredLabels ? requiredLabels.includes(labelName) : false;
+};
export const LabelTask = ({
task,
@@ -14,15 +20,33 @@ export const LabelTask = ({
isEditable,
onReplyChanged,
onValidityChanged,
-}: TaskSurveyProps<{ text: string; labels: Record; message_id: string }>) => {
- const [sliderValues, setSliderValues] = useState(new Array(task.valid_labels.length).fill(null));
+}: TaskSurveyProps; message_id: string }>) => {
+ const { t } = useTranslation("labelling");
+ const [values, setValues] = useState(new Array(task.labels.length).fill(null));
+ const [userInputMade, setUserInputMade] = useBoolean(false);
+ // Initial setup to run when the task changes
useEffect(() => {
- console.assert(task.valid_labels.length === sliderValues.length);
- const labels = Object.fromEntries(task.valid_labels.map((label, i) => [label, sliderValues[i]]));
- onReplyChanged({ labels, text: task.reply || task.prompt, message_id: task.message_id });
- onValidityChanged(sliderValues.every((value) => value !== null) ? "VALID" : "INVALID");
- }, [task, sliderValues, onReplyChanged, onValidityChanged]);
+ setValues(new Array(task.labels.length).fill(null));
+ onValidityChanged(task.labels.some(({ name }) => isRequired(name, task.mandatory_labels)) ? "INVALID" : "DEFAULT");
+ setUserInputMade.off();
+ }, [task, setUserInputMade, onValidityChanged]);
+
+ // Update the reply and validity when the values change
+ useEffect(() => {
+ onReplyChanged({
+ text: "unused?",
+ labels: Object.fromEntries(task.labels.map(({ name }, idx) => [name, values[idx] || 0])),
+ message_id: task.message_id,
+ });
+ onValidityChanged(
+ task.labels.some(({ name }, idx) => values[idx] === null && isRequired(name, task.mandatory_labels))
+ ? "INVALID"
+ : userInputMade
+ ? "VALID"
+ : "DEFAULT"
+ );
+ }, [task, values, onReplyChanged, userInputMade, onValidityChanged]);
const cardColor = useColorModeValue("gray.50", "gray.800");
const isSpamTask = task.mode === "simple" && task.valid_labels.length === 1 && task.valid_labels[0] === "spam";
@@ -32,71 +56,32 @@ export const LabelTask = ({
<>
- {task.conversation ? (
+ {task.type !== TaskType.label_initial_prompt ? (
-
+
) : (
-
+
)}
>
- {isSpamTask ? (
- setSliderValues([value])}
- isEditable={isEditable}
- />
- ) : (
-
- The highlighted message:
-
-
- )}
+ {
+ setValues(values);
+ setUserInputMade.on();
+ }}
+ />
);
};
-
-const SpamTaskInput = ({
- isEditable,
- value,
- onChange,
-}: {
- isEditable: boolean;
- value: number;
- onChange: (number) => void;
-}) => {
- return (
-
- Is the highlighted message spam?
-
-
-
- );
-};
diff --git a/website/src/components/Tasks/Task/Task.tsx b/website/src/components/Tasks/Task/Task.tsx
index ae82ef97..51ba6fa3 100644
--- a/website/src/components/Tasks/Task/Task.tsx
+++ b/website/src/components/Tasks/Task/Task.tsx
@@ -1,5 +1,6 @@
import { useTranslation } from "next-i18next";
-import { useRef, useState } from "react";
+import { useCallback, useEffect, useReducer } from "react";
+import { useMemo, useRef } from "react";
import { TaskControls } from "src/components/Survey/TaskControls";
import { CreateTask } from "src/components/Tasks/CreateTask";
import { EvaluateTask } from "src/components/Tasks/EvaluateTask";
@@ -8,15 +9,52 @@ import { TaskCategory, TaskInfo, TaskInfos } from "src/components/Tasks/TaskType
import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning";
import { post } from "src/lib/api";
import { getTypeSafei18nKey } from "src/lib/i18n";
-import { TaskContent, TaskReplyValidity } from "src/types/Task";
+import { BaseTask, TaskContent, TaskReplyValidity } from "src/types/Task";
import useSWRMutation from "swr/mutation";
-export type TaskStatus = "NOT_SUBMITTABLE" | "DEFAULT" | "VALID" | "REVIEW" | "SUBMITTED";
+interface EditMode {
+ mode: "EDIT";
+ replyValidity: TaskReplyValidity;
+}
+interface ReviewMode {
+ mode: "REVIEW";
+}
+interface DefaultWarnMode {
+ mode: "DEFAULT_WARN";
+}
+interface SubmittedMode {
+ mode: "SUBMITTED";
+}
-export interface TaskSurveyProps {
- // we need a task type
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- task: any;
+export type TaskStatus = EditMode | DefaultWarnMode | ReviewMode | SubmittedMode;
+
+interface NewTask {
+ action: "NEW_TASK";
+}
+
+interface Review {
+ action: "REVIEW";
+}
+
+interface SetSubmitted {
+ action: "SET_SUBMITTED";
+}
+
+interface ReturnToEdit {
+ action: "RETURN_EDIT";
+}
+
+interface AcceptDefault {
+ action: "ACCEPT_DEFAULT";
+}
+
+interface UpdateValidity {
+ action: "UPDATE_VALIDITY";
+ replyValidity: TaskReplyValidity;
+}
+
+export interface TaskSurveyProps {
+ task: TaskType;
taskType: TaskInfo;
isEditable: boolean;
isDisabled?: boolean;
@@ -26,13 +64,63 @@ export interface TaskSurveyProps {
export const Task = ({ frontendId, task, trigger, mutate }) => {
const { t } = useTranslation("tasks");
- const [taskStatus, setTaskStatus] = useState("NOT_SUBMITTABLE");
+ const [taskStatus, taskEvent] = useReducer(
+ (
+ status: TaskStatus,
+ event: NewTask | UpdateValidity | AcceptDefault | Review | ReturnToEdit | SetSubmitted
+ ): TaskStatus => {
+ switch (event.action) {
+ case "NEW_TASK":
+ return { mode: "EDIT", replyValidity: "INVALID" };
+ case "UPDATE_VALIDITY":
+ return status.mode === "EDIT" ? { mode: "EDIT", replyValidity: event.replyValidity } : status;
+ case "ACCEPT_DEFAULT":
+ return status.mode === "DEFAULT_WARN" ? { mode: "REVIEW" } : status;
+ case "REVIEW": {
+ if (status.mode === "EDIT") {
+ switch (status.replyValidity) {
+ case "DEFAULT":
+ return { mode: "DEFAULT_WARN" };
+ case "VALID":
+ return { mode: "REVIEW" };
+ }
+ }
+ return status;
+ }
+ case "RETURN_EDIT": {
+ switch (status.mode) {
+ case "REVIEW":
+ return { mode: "EDIT", replyValidity: "VALID" };
+ case "DEFAULT_WARN":
+ return { mode: "EDIT", replyValidity: "DEFAULT" };
+ default:
+ return status;
+ }
+ }
+ case "SET_SUBMITTED": {
+ return status.mode === "REVIEW" ? { mode: "SUBMITTED" } : status;
+ }
+ }
+ },
+ { mode: "EDIT", replyValidity: "INVALID" }
+ );
+
const replyContent = useRef(null);
- const [showUnchangedWarning, setShowUnchangedWarning] = useState(false);
+ const updateValidity = useCallback(
+ (replyValidity: TaskReplyValidity) => taskEvent({ action: "UPDATE_VALIDITY", replyValidity }),
+ [taskEvent]
+ );
+
+ useEffect(() => {
+ taskEvent({ action: "NEW_TASK" });
+ }, [task.id, updateValidity]);
const rootEl = useRef(null);
- const taskType = TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode);
+ const taskType = useMemo(
+ () => TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode),
+ [task.type, task.mode]
+ );
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, {
onSuccess: async () => {
@@ -47,79 +135,36 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
});
};
- const edit_mode = taskStatus === "NOT_SUBMITTABLE" || taskStatus === "DEFAULT" || taskStatus === "VALID";
- const submitted = taskStatus === "SUBMITTED";
-
- const onValidityChanged = (validity: TaskReplyValidity) => {
- if (!edit_mode) return;
- switch (validity) {
- case "DEFAULT":
- if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT");
- break;
- case "VALID":
- if (taskStatus !== "VALID") setTaskStatus("VALID");
- break;
- case "INVALID":
- if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE");
- break;
- }
- };
-
- const onReplyChanged = (content: TaskContent) => {
- replyContent.current = content;
- };
-
- const reviewResponse = () => {
- switch (taskStatus) {
- case "DEFAULT":
- setShowUnchangedWarning(true);
- break;
- case "VALID":
- setTaskStatus("REVIEW");
- break;
- default:
- return;
- }
- };
-
- const editResponse = () => {
- switch (taskStatus) {
- case "REVIEW":
- setTaskStatus("VALID");
- break;
- default:
- return;
- }
- };
+ const onReplyChanged = useCallback(
+ (content: TaskContent) => {
+ replyContent.current = content;
+ },
+ [replyContent]
+ );
const submitResponse = () => {
- switch (taskStatus) {
- case "REVIEW": {
- trigger({
- id: frontendId,
- update_type: taskType.update_type,
- content: replyContent.current,
- });
- setTaskStatus("SUBMITTED");
- scrollToTop(rootEl.current);
- break;
- }
- default:
- return;
+ if (taskStatus.mode === "REVIEW") {
+ trigger({
+ id: frontendId,
+ update_type: taskType.update_type,
+ content: replyContent.current,
+ });
+ taskEvent({ action: "SET_SUBMITTED" });
+ scrollToTop(rootEl.current);
}
};
- function taskTypeComponent() {
+ const taskTypeComponent = useMemo(() => {
switch (taskType.category) {
case TaskCategory.Create:
return (
);
case TaskCategory.Evaluate:
@@ -127,10 +172,10 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
);
case TaskCategory.Label:
@@ -138,37 +183,34 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
);
}
- }
+ }, [task, taskType, taskStatus.mode, onReplyChanged, updateValidity]);
return (
- {taskTypeComponent()}
+ {taskTypeComponent}
taskEvent({ action: "RETURN_EDIT" })}
+ onReview={() => taskEvent({ action: "REVIEW" })}
onSubmit={submitResponse}
onSkip={rejectTask}
/>
setShowUnchangedWarning(false)}
+ onClose={() => taskEvent({ action: "RETURN_EDIT" })}
onContinueAnyway={() => {
- if (taskStatus === "DEFAULT") {
- setTaskStatus("REVIEW");
- setShowUnchangedWarning(false);
- }
+ taskEvent({ action: "ACCEPT_DEFAULT" });
}}
/>
diff --git a/website/src/lib/api.ts b/website/src/lib/api.ts
index d61016d2..2649daf8 100644
--- a/website/src/lib/api.ts
+++ b/website/src/lib/api.ts
@@ -17,7 +17,8 @@ export const post = (url: string, { arg: data }) => api.post(url, data).then((re
api.interceptors.response.use(
(response) => response,
(error) => {
- throw new OasstError(error.message ?? error, error.error_code, error?.response?.status || -1);
+ const err = error?.response?.data;
+ throw new OasstError(err?.message ?? error, err?.errorCode, error?.response?.httpStatusCode || -1);
}
);
diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts
index 42c6cf79..ee004ba9 100644
--- a/website/src/lib/auth.ts
+++ b/website/src/lib/auth.ts
@@ -21,14 +21,14 @@ const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiRe
* Wraps any API Route handler and verifies that the user has the appropriate
* role before running the handler. Returns a 403 otherwise.
*/
-const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
+const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, token: JWT) => void) => {
return async (req: NextApiRequest, res: NextApiResponse) => {
const token = await getToken({ req });
if (!token || token.role !== role) {
res.status(403).end();
return;
}
- return handler(req, res);
+ return handler(req, res, token);
};
};
diff --git a/website/src/lib/constants.ts b/website/src/lib/constants.ts
new file mode 100644
index 00000000..a260fa28
--- /dev/null
+++ b/website/src/lib/constants.ts
@@ -0,0 +1,43 @@
+import {
+ useCreateAssistantReply,
+ useCreateInitialPrompt,
+ useCreatePrompterReply,
+} from "src/hooks/tasks/useCreateReply";
+import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI";
+import {
+ useLabelAssistantReplyTask,
+ useLabelInitialPromptTask,
+ useLabelPrompterReplyTask,
+} from "src/hooks/tasks/useLabelingTask";
+import {
+ useRankAssistantRepliesTask,
+ useRankInitialPromptsTask,
+ useRankPrompterRepliesTask,
+} from "src/hooks/tasks/useRankReplies";
+import { TaskApiHooks } from "src/types/Hooks";
+import { TaskType } from "src/types/Task";
+
+export const ERROR_CODES = {
+ TASK_REQUESTED_TYPE_NOT_AVAILABLE: 1006,
+ TASK_INVALID_REQUEST_TYPE: 1000,
+ TASK_ACK_FAILED: 1001,
+ TASK_NACK_FAILED: 1002,
+ TASK_INVALID_RESPONSE_TYPE: 1003,
+ TASK_INTERACTION_REQUEST_FAILED: 1004,
+ TASK_GENERATION_FAILED: 1005,
+ TASK_AVAILABILITY_QUERY_FAILED: 1007,
+ TASK_MESSAGE_TOO_LONG: 1008,
+};
+
+export const taskApiHooks: TaskApiHooks = {
+ [TaskType.random]: useGenericTaskAPI,
+ [TaskType.assistant_reply]: useCreateAssistantReply,
+ [TaskType.initial_prompt]: useCreateInitialPrompt,
+ [TaskType.label_assistant_reply]: useLabelAssistantReplyTask,
+ [TaskType.label_initial_prompt]: useLabelInitialPromptTask,
+ [TaskType.label_prompter_reply]: useLabelPrompterReplyTask,
+ [TaskType.prompter_reply]: useCreatePrompterReply,
+ [TaskType.rank_assistant_replies]: useRankAssistantRepliesTask,
+ [TaskType.rank_initial_prompts]: useRankInitialPromptsTask,
+ [TaskType.rank_prompter_replies]: useRankPrompterRepliesTask,
+};
diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts
index bd263400..b9a9489e 100644
--- a/website/src/lib/oasst_api_client.ts
+++ b/website/src/lib/oasst_api_client.ts
@@ -1,4 +1,4 @@
-import type { Message } from "src/types/Conversation";
+import type { EmojiOp, Message } from "src/types/Conversation";
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
import type { AvailableTasks } from "src/types/Task";
import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users";
@@ -18,10 +18,16 @@ export class OasstError {
export class OasstApiClient {
oasstApiUrl: string;
oasstApiKey: string;
+ userHeaders: Record = {};
- constructor(oasstApiUrl: string, oasstApiKey: string) {
+ constructor(oasstApiUrl: string, oasstApiKey: string, user?: BackendUserCore) {
this.oasstApiUrl = oasstApiUrl;
this.oasstApiKey = oasstApiKey;
+ if (user) {
+ this.userHeaders = {
+ "X-OASST-USER": `${user.auth_method}:${user.id}`,
+ };
+ }
}
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
@@ -76,6 +82,27 @@ export class OasstApiClient {
return this.post("/api/v1/tasks/availability", user);
}
+ /**
+ * Returns the `Message`s associated with `user_id` in the backend.
+ */
+ async fetch_message(message_id: string, user: BackendUserCore): Promise {
+ return this.get(`/api/v1/messages/${message_id}?username=${user.id}&auth_method=${user.auth_method}`);
+ }
+
+ /**
+ * Send a report about a message
+ */
+ async send_report(message_id: string, user: BackendUserCore, text: string) {
+ return this.post("/api/v1/text_labels", {
+ type: "text_labels",
+ message_id,
+ labels: [], // Not yet implemented
+ text,
+ is_report: true,
+ user,
+ });
+ }
+
/**
* Returns the message stats from the backend.
*/
@@ -144,7 +171,7 @@ export class OasstApiClient {
time_frame: LeaderboardTimeFrame,
{ limit = 20 }: { limit?: number }
): Promise {
- return this.get(`/api/v1/leaderboards/${time_frame}`, { limit });
+ return this.get(`/api/v1/leaderboards/${time_frame}`, { max_count: limit });
}
/**
@@ -154,6 +181,17 @@ export class OasstApiClient {
return this.post(`/api/v1/tasks/availability?lang=${lang}`, user);
}
+ /**
+ * Add/remove an emoji on a message for a user
+ */
+ async set_user_message_emoji(message_id: string, user: BackendUserCore, emoji: string, op: EmojiOp): Promise {
+ await this.post(`/api/v1/messages/${message_id}/emoji`, {
+ user,
+ emoji,
+ op,
+ });
+ }
+
private async post(path: string, body: unknown) {
return this.request("POST", path, {
body: JSON.stringify(body),
@@ -183,9 +221,10 @@ export class OasstApiClient {
method,
...init,
headers: {
+ ...init?.headers,
+ ...this.userHeaders,
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
- ...init?.headers,
},
});
@@ -195,8 +234,7 @@ export class OasstApiClient {
if (resp.status >= 300) {
const errorText = await resp.text();
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- let error: any;
+ let error;
try {
error = JSON.parse(errorText);
} catch (e) {
@@ -207,8 +245,24 @@ export class OasstApiClient {
return await resp.json();
}
+
+ fetch_my_messages(user: BackendUserCore) {
+ const params = new URLSearchParams({
+ username: user.id,
+ auth_method: user.auth_method,
+ });
+ return this.get(`/api/v1/messages?${params}`);
+ }
+
+ fetch_recent_messages() {
+ return this.get(`/api/v1/messages`);
+ }
+
+ fetch_message_children(messageId: string) {
+ return this.get(`/api/v1/messages/${messageId}/children`);
+ }
+
+ fetch_conversation(messageId: string) {
+ return this.get(`/api/v1/messages/${messageId}/conversation`);
+ }
}
-
-const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
-
-export { oasstApiClient };
diff --git a/website/src/lib/oasst_client_factory.ts b/website/src/lib/oasst_client_factory.ts
new file mode 100644
index 00000000..9f9bd657
--- /dev/null
+++ b/website/src/lib/oasst_client_factory.ts
@@ -0,0 +1,11 @@
+import { JWT } from "next-auth/jwt";
+import { OasstApiClient } from "src/lib/oasst_api_client";
+import { getBackendUserCore } from "src/lib/users";
+import { BackendUserCore } from "src/types/Users";
+
+export const createApiClientFromUser = (user: BackendUserCore) =>
+ new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY, user);
+
+export const createApiClient = async (token: JWT) => createApiClientFromUser(await getBackendUserCore(token.sub));
+
+export const userlessApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx
index b53bb7c0..a68bca16 100644
--- a/website/src/pages/admin/manage_user/[id].tsx
+++ b/website/src/pages/admin/manage_user/[id].tsx
@@ -10,7 +10,7 @@ import { getAdminLayout } from "src/components/Layout";
import { Role, RoleSelect } from "src/components/RoleSelect";
import { UserMessagesCell } from "src/components/UserMessagesCell";
import { post } from "src/lib/api";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { userlessApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import useSWRMutation from "swr/mutation";
@@ -113,7 +113,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType {
+ // NOTE: why are we using a dummy user here?
const dummyUser = {
id: "__dummy_user__",
display_name: "Dummy User",
auth_method: "local",
};
+ const oasstApiClient = createApiClientFromUser(dummyUser);
const [tasksAvailabilityOutcome, statsOutcome, treeManagerOutcome] = await Promise.allSettled([
oasstApiClient.fetch_tasks_availability(dummyUser),
oasstApiClient.fetch_stats(),
diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts
index 341ec736..c71159ad 100644
--- a/website/src/pages/api/admin/update_user.ts
+++ b/website/src/pages/api/admin/update_user.ts
@@ -1,22 +1,19 @@
import { withRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
/**
* Update's the user's data in the database. Accessible only to admins.
*/
-const handler = withRole("admin", async (req, res) => {
+const handler = withRole("admin", async (req, res, token) => {
const { id, auth_method, user_id, notes, role } = req.body;
+ const oasstApiClient = await createApiClient(token);
// If the user is authorized by the web, update their role.
if (auth_method === "local") {
await prisma.user.update({
- where: {
- id,
- },
- data: {
- role,
- },
+ where: { id },
+ data: { role },
});
}
// Tell the backend the user's enabled or not enabled status.
diff --git a/website/src/pages/api/admin/user_messages.ts b/website/src/pages/api/admin/user_messages.ts
index 236afa2d..0223e8e3 100644
--- a/website/src/pages/api/admin/user_messages.ts
+++ b/website/src/pages/api/admin/user_messages.ts
@@ -1,12 +1,13 @@
import { withRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
import type { Message } from "src/types/Conversation";
/**
* Returns the messages recorded by the backend for a user.
*/
-const handler = withRole("admin", async (req, res) => {
+const handler = withRole("admin", async (req, res, token) => {
const { user } = req.query;
+ const oasstApiClient = await createApiClient(token);
const messages: Message[] = await oasstApiClient.fetch_user_messages(user as string);
res.status(200).json(messages);
});
diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts
index d10c91b0..eae0f072 100644
--- a/website/src/pages/api/admin/users.ts
+++ b/website/src/pages/api/admin/users.ts
@@ -1,5 +1,5 @@
import { withRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import { FetchUsersParams } from "src/types/Users";
@@ -17,9 +17,10 @@ const PAGE_SIZE = 20;
* - `direction`: Either "forward" or "backward" representing the pagination
* direction.
*/
-const handler = withRole("admin", async (req, res) => {
+const handler = withRole("admin", async (req, res, token) => {
const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query;
+ const oasstApiClient = await createApiClient(token);
// First, get all the users according to the backend.
const { items: all_users, ...rest } = await oasstApiClient.fetch_users({
searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"],
diff --git a/website/src/pages/api/available_tasks.ts b/website/src/pages/api/available_tasks.ts
index 79dcc8b9..218e3864 100644
--- a/website/src/pages/api/available_tasks.ts
+++ b/website/src/pages/api/available_tasks.ts
@@ -1,9 +1,10 @@
import { withoutRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
const user = await getBackendUserCore(token.sub);
+ const oasstApiClient = createApiClientFromUser(user);
const userLanguage = getUserLanguage(req);
const availableTasks = await oasstApiClient.fetch_available_tasks(user, userLanguage);
res.status(200).json(availableTasks);
diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts
index be91e7b4..fad1d8a6 100644
--- a/website/src/pages/api/leaderboard.ts
+++ b/website/src/pages/api/leaderboard.ts
@@ -1,11 +1,12 @@
import { withoutRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
/**
* Returns the set of valid labels that can be applied to messages.
*/
-const handler = withoutRole("banned", async (req, res) => {
+const handler = withoutRole("banned", async (req, res, token) => {
+ const oasstApiClient = await createApiClient(token);
const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day;
const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number });
res.status(200).json(info);
diff --git a/website/src/pages/api/messages/[id]/children.ts b/website/src/pages/api/messages/[id]/children.ts
index 4a184c11..0185b615 100644
--- a/website/src/pages/api/messages/[id]/children.ts
+++ b/website/src/pages/api/messages/[id]/children.ts
@@ -1,18 +1,10 @@
import { withoutRole } from "src/lib/auth";
+import { createApiClient } from "src/lib/oasst_client_factory";
-const handler = withoutRole("banned", async (req, res) => {
+const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
-
- const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- });
- const messages = await messagesRes.json();
-
- // Send recieved messages to the client.
+ const client = await createApiClient(token);
+ const messages = await client.fetch_message_children(id as string);
res.status(200).json(messages);
});
diff --git a/website/src/pages/api/messages/[id]/conversation.ts b/website/src/pages/api/messages/[id]/conversation.ts
index 0c401883..fc284e40 100644
--- a/website/src/pages/api/messages/[id]/conversation.ts
+++ b/website/src/pages/api/messages/[id]/conversation.ts
@@ -1,18 +1,10 @@
import { withoutRole } from "src/lib/auth";
+import { createApiClient } from "src/lib/oasst_client_factory";
-const handler = withoutRole("banned", async (req, res) => {
+const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
-
- const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- });
- const messages = await messagesRes.json();
-
- // Send recieved messages to the client.
+ const client = await createApiClient(token);
+ const messages = await client.fetch_conversation(id as string);
res.status(200).json(messages);
});
diff --git a/website/src/pages/api/messages/[id]/emoji.ts b/website/src/pages/api/messages/[id]/emoji.ts
new file mode 100644
index 00000000..8d6257c1
--- /dev/null
+++ b/website/src/pages/api/messages/[id]/emoji.ts
@@ -0,0 +1,31 @@
+import { withoutRole } from "src/lib/auth";
+import { createApiClientFromUser } from "src/lib/oasst_client_factory";
+import { getBackendUserCore } from "src/lib/users";
+
+const handler = withoutRole("banned", async (req, res, token) => {
+ const { id } = req.query;
+
+ if (!id) {
+ res.status(400).end();
+ return;
+ }
+
+ const messageId = id as string;
+
+ const { emoji, op } = req.body;
+
+ const user = await getBackendUserCore(token.sub);
+ const oasstApiClient = createApiClientFromUser(user);
+ try {
+ await oasstApiClient.set_user_message_emoji(messageId, user, emoji, op);
+ } catch (err) {
+ console.error(JSON.stringify(err));
+ return res.status(500).json(err);
+ }
+
+ // Get updated emoji
+ const message = await oasstApiClient.fetch_message(messageId, user);
+ res.status(200).json({ emojis: message.emojis, user_emojis: message.user_emojis });
+});
+
+export default handler;
diff --git a/website/src/pages/api/messages/[id]/index.ts b/website/src/pages/api/messages/[id]/index.ts
index d9c2bc4a..b3361a2d 100644
--- a/website/src/pages/api/messages/[id]/index.ts
+++ b/website/src/pages/api/messages/[id]/index.ts
@@ -1,18 +1,12 @@
import { withoutRole } from "src/lib/auth";
+import { createApiClientFromUser } from "src/lib/oasst_client_factory";
+import { getBackendUserCore } from "src/lib/users";
-const handler = withoutRole("banned", async (req, res) => {
+const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
-
- const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- });
- const message = await messageRes.json();
-
- // Send recieved messages to the client.
+ const user = await getBackendUserCore(token.sub);
+ const client = createApiClientFromUser(user);
+ const message = await client.fetch_message(id as string, user);
res.status(200).json(message);
});
diff --git a/website/src/pages/api/messages/[id]/parent.ts b/website/src/pages/api/messages/[id]/parent.ts
index ba332be9..9b8c64d4 100644
--- a/website/src/pages/api/messages/[id]/parent.ts
+++ b/website/src/pages/api/messages/[id]/parent.ts
@@ -1,6 +1,8 @@
import { withoutRole } from "src/lib/auth";
+import { createApiClient, createApiClientFromUser } from "src/lib/oasst_client_factory";
+import { getBackendUserCore } from "src/lib/users";
-const handler = withoutRole("banned", async (req, res) => {
+const handler = withoutRole("banned", async (req, res, token) => {
const { id } = req.query;
if (!id) {
@@ -8,32 +10,16 @@ const handler = withoutRole("banned", async (req, res) => {
return;
}
- const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- });
-
- const message = await messageRes.json();
+ const user = await getBackendUserCore(token.sub);
+ const client = createApiClientFromUser(user);
+ const message = await client.fetch_message(id as string, user);
if (!message.parent_id) {
res.status(404).end();
return;
}
- const parentRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${message.parent_id}`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- });
-
- const parent = await parentRes.json();
-
- // Send recieved messages to the client.
+ const parent = await client.fetch_message(message.parent_id, user);
res.status(200).json(parent);
});
diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts
index cb9728fe..fbcaee3c 100644
--- a/website/src/pages/api/messages/index.ts
+++ b/website/src/pages/api/messages/index.ts
@@ -1,15 +1,9 @@
import { withoutRole } from "src/lib/auth";
+import { createApiClient } from "src/lib/oasst_client_factory";
-const handler = withoutRole("banned", async (req, res) => {
- const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- },
- });
- const messages = await messagesRes.json();
-
- // Send recieved messages to the client.
+const handler = withoutRole("banned", async (req, res, token) => {
+ const client = await createApiClient(token);
+ const messages = await client.fetch_recent_messages();
res.status(200).json(messages);
});
diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts
index 6f39aad1..bffe2fb6 100644
--- a/website/src/pages/api/messages/user.ts
+++ b/website/src/pages/api/messages/user.ts
@@ -1,23 +1,11 @@
import { withoutRole } from "src/lib/auth";
+import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import { getBackendUserCore } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
- //TODO: add params if needed
const user = await getBackendUserCore(token.sub);
- const params = new URLSearchParams({
- username: user.id,
- auth_method: user.auth_method,
- });
-
- const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
- method: "GET",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- },
- });
- const messages = await messagesRes.json();
-
- // Send recieved messages to the client.
+ const client = createApiClientFromUser(user);
+ const messages = await client.fetch_my_messages(user);
res.status(200).json(messages);
});
diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts
index 360b8faa..34993f4f 100644
--- a/website/src/pages/api/new_task/[task_type].ts
+++ b/website/src/pages/api/new_task/[task_type].ts
@@ -1,5 +1,5 @@
import { withoutRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClientFromUser } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
@@ -17,6 +17,7 @@ const handler = withoutRole("banned", async (req, res, token) => {
const userLanguage = getUserLanguage(req);
const user = await getBackendUserCore(token.sub);
+ const oasstApiClient = createApiClientFromUser(user);
let task;
try {
task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage);
diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts
index 0084fffa..2e3f4fa1 100644
--- a/website/src/pages/api/reject_task.ts
+++ b/website/src/pages/api/reject_task.ts
@@ -1,19 +1,21 @@
import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
-const handler = withoutRole("banned", async (req, res) => {
+const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, reason } = req.body;
- const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
+ const [oasstApiClient, registeredTask] = await Promise.all([
+ createApiClient(token),
+ prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
+ ]);
- const task = registeredTask.task as Prisma.JsonObject;
- const id = task.id as string;
+ const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
// Update the backend with the rejection
- await oasstApiClient.nackTask(id, reason);
+ await oasstApiClient.nackTask(taskId, reason);
// Send the results to the client.
res.status(200).json({});
diff --git a/website/src/pages/api/report.ts b/website/src/pages/api/report.ts
new file mode 100644
index 00000000..9b904df4
--- /dev/null
+++ b/website/src/pages/api/report.ts
@@ -0,0 +1,25 @@
+import { withoutRole } from "src/lib/auth";
+import { createApiClientFromUser } from "src/lib/oasst_client_factory";
+import { getBackendUserCore } from "src/lib/users";
+
+/**
+ * Adds a report for a message
+ *
+ */
+const handler = withoutRole("banned", async (req, res, token) => {
+ // Parse out the local message_id, and the interaction contents.
+ const { message_id, text } = req.body;
+
+ const user = await getBackendUserCore(token.sub);
+ const oasstApiClient = createApiClientFromUser(user);
+ try {
+ await oasstApiClient.send_report(message_id, user, text);
+ } catch (err) {
+ console.error(JSON.stringify(err));
+ return res.status(500).json(err);
+ }
+
+ res.status(200).end();
+});
+
+export default handler;
diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts
index 93a150d5..3c54a89f 100644
--- a/website/src/pages/api/set_label.ts
+++ b/website/src/pages/api/set_label.ts
@@ -5,8 +5,10 @@ import { withoutRole } from "src/lib/auth";
*
*/
const handler = withoutRole("banned", async (req, res, token) => {
+ // TODO: move to oasst_api_client
// Parse out the local message_id, and the interaction contents.
- const { message_id, label_map, text } = req.body;
+ const { message_id, label_map } = req.body;
+
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
method: "POST",
headers: {
@@ -17,7 +19,8 @@ const handler = withoutRole("banned", async (req, res, token) => {
type: "text_labels",
message_id: message_id,
labels: label_map,
- text: text,
+ text: "", // used only in reporting
+ is_report: false,
user: {
id: token.sub,
display_name: token.name || token.email,
diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts
index 6f08d640..1b4f2eda 100644
--- a/website/src/pages/api/update_task.ts
+++ b/website/src/pages/api/update_task.ts
@@ -1,6 +1,6 @@
import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
import prisma from "src/lib/prismadb";
import { getBackendUserCore, getUserLanguage } from "src/lib/users";
@@ -18,13 +18,18 @@ const handler = withoutRole("banned", async (req, res, token) => {
// Parse out the local task ID and the interaction contents.
const { id: frontendId, content, update_type } = req.body;
- // Record that the user has done meaningful work and is no longer new.
- await prisma.user.update({ where: { id: token.sub }, data: { isNew: false } });
+ // do in parallel since they are independent
+ const [_, registeredTask, oasstApiClient] = await Promise.all([
+ // Record that the user has done meaningful work and is no longer new.
+ prisma.user.update({ where: { id: token.sub }, data: { isNew: false } }),
+ // Accept the task so that we can complete it, this will probably go away soon.
+ prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } }),
+ // Create client for upcoming requests
+ createApiClient(token),
+ ]);
+
+ const taskId = (registeredTask.task as Prisma.JsonObject).id as string;
- // Accept the task so that we can complete it, this will probably go away soon.
- const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
- const task = registeredTask.task as Prisma.JsonObject;
- const taskId = task.id as string;
await oasstApiClient.ackTask(taskId, registeredTask.id);
// Log the interaction locally to create our user_post_id needed by the Task
diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts
index 5f61ff2f..dca92d90 100644
--- a/website/src/pages/api/valid_labels.ts
+++ b/website/src/pages/api/valid_labels.ts
@@ -1,11 +1,12 @@
import { withoutRole } from "src/lib/auth";
-import { oasstApiClient } from "src/lib/oasst_api_client";
+import { createApiClient } from "src/lib/oasst_client_factory";
/**
* Returns the set of valid labels that can be applied to messages.
*/
-const handler = withoutRole("banned", async (req, res) => {
- const valid_labels = await oasstApiClient.fetch_valid_text();
+const handler = withoutRole("banned", async (req, res, token) => {
+ const client = await createApiClient(token);
+ const valid_labels = await client.fetch_valid_text();
res.status(200).json(valid_labels);
});
diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx
index 1c83eb23..0f3095a9 100644
--- a/website/src/pages/create/assistant_reply.tsx
+++ b/website/src/pages/create/assistant_reply.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const AssistantReply = () => {
- const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Reply as Assistant
-
-
-
- >
- );
-};
+const AssistantReply = () => ;
AssistantReply.getLayout = getDashboardLayout;
diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx
index 639df68f..c73f2e5d 100644
--- a/website/src/pages/create/initial_prompt.tsx
+++ b/website/src/pages/create/initial_prompt.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const InitialPrompt = () => {
- const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Initial Prompt
-
-
-
- >
- );
-};
+const InitialPrompt = () => ;
InitialPrompt.getLayout = getDashboardLayout;
diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx
index 5898439c..39218476 100644
--- a/website/src/pages/create/user_reply.tsx
+++ b/website/src/pages/create/user_reply.tsx
@@ -1,33 +1,10 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const UserReply = () => {
- const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
+const PrompterReply = () => ;
- if (isLoading) {
- return ;
- }
+PrompterReply.getLayout = getDashboardLayout;
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Reply as User
-
-
-
- >
- );
-};
-
-UserReply.getLayout = getDashboardLayout;
-
-export default UserReply;
+export default PrompterReply;
diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx
index da79d92f..dd4c1df9 100644
--- a/website/src/pages/evaluate/rank_assistant_replies.tsx
+++ b/website/src/pages/evaluate/rank_assistant_replies.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const RankAssistantReplies = () => {
- const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Rank Assistant Replies
-
-
-
- >
- );
-};
+const RankAssistantReplies = () => ;
RankAssistantReplies.getLayout = getDashboardLayout;
diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx
index f23fc0ed..1eb91289 100644
--- a/website/src/pages/evaluate/rank_initial_prompts.tsx
+++ b/website/src/pages/evaluate/rank_initial_prompts.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const RankInitialPrompts = () => {
- const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Rank Initial Prompts
-
-
-
- >
- );
-};
+const RankInitialPrompts = () => ;
RankInitialPrompts.getLayout = getDashboardLayout;
diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx
index cee82b87..a1caba59 100644
--- a/website/src/pages/evaluate/rank_user_replies.tsx
+++ b/website/src/pages/evaluate/rank_user_replies.tsx
@@ -1,33 +1,10 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const RankUserReplies = () => {
- const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
+const RankPrompterReplies = () => ;
- if (isLoading) {
- return ;
- }
+RankPrompterReplies.getLayout = getDashboardLayout;
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Rank User Replies
-
-
-
- >
- );
-};
-
-RankUserReplies.getLayout = getDashboardLayout;
-
-export default RankUserReplies;
+export default RankPrompterReplies;
diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx
index 07a6cb1c..8be12b41 100644
--- a/website/src/pages/label/label_assistant_reply.tsx
+++ b/website/src/pages/label/label_assistant_reply.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const LabelAssistantReply = () => {
- const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Label Assistant Reply
-
-
-
- >
- );
-};
+const LabelAssistantReply = () => ;
LabelAssistantReply.getLayout = getDashboardLayout;
diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx
index 8735044f..c5fed344 100644
--- a/website/src/pages/label/label_initial_prompt.tsx
+++ b/website/src/pages/label/label_initial_prompt.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const LabelInitialPrompt = () => {
- const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Label Initial Prompt
-
-
-
- >
- );
-};
+const LabelInitialPrompt = () => ;
LabelInitialPrompt.getLayout = getDashboardLayout;
diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx
index 17164e11..33e8aba4 100644
--- a/website/src/pages/label/label_prompter_reply.tsx
+++ b/website/src/pages/label/label_prompter_reply.tsx
@@ -1,32 +1,9 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
+import { TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
-const LabelPrompterReply = () => {
- const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
-
- if (isLoading) {
- return ;
- }
-
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Label Prompter Reply
-
-
-
- >
- );
-};
+const LabelPrompterReply = () => ;
LabelPrompterReply.getLayout = getDashboardLayout;
diff --git a/website/src/pages/messages/[id]/index.tsx b/website/src/pages/messages/[id]/index.tsx
index 51c28c42..158d28e8 100644
--- a/website/src/pages/messages/[id]/index.tsx
+++ b/website/src/pages/messages/[id]/index.tsx
@@ -1,5 +1,6 @@
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
+import { useTranslation } from "next-i18next";
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
import { getDashboardLayout } from "src/components/Layout";
import { MessageLoading } from "src/components/Loading/MessageLoading";
@@ -10,6 +11,7 @@ import { Message } from "src/types/Conversation";
import useSWRImmutable from "swr/immutable";
const MessageDetail = ({ id }: { id: string }) => {
+ const { t } = useTranslation(["message", "common"]);
const backgroundColor = useColorModeValue("white", "gray.800");
const { isLoading: isLoadingParent, data: parent } = useSWRImmutable(`/api/messages/${id}/parent`, get);
@@ -20,7 +22,7 @@ const MessageDetail = ({ id }: { id: string }) => {
return (
<>
- Open Assistant
+ {t("common:title")}
{
<>
- Parent
+ {t("parent")}
-
+
>
@@ -54,7 +56,7 @@ MessageDetail.getLayout = (page) => getDashboardLayout(page);
export const getServerSideProps = async ({ locale, query }) => ({
props: {
id: query.id,
- ...(await serverSideTranslations(locale, ["common"])),
+ ...(await serverSideTranslations(locale, ["common", "message"])),
},
});
diff --git a/website/src/pages/tasks/random.tsx b/website/src/pages/tasks/random.tsx
index f1c04d2c..cd7ed458 100644
--- a/website/src/pages/tasks/random.tsx
+++ b/website/src/pages/tasks/random.tsx
@@ -1,34 +1,10 @@
-import Head from "next/head";
-import { TaskEmptyState } from "src/components/EmptyState";
import { getDashboardLayout } from "src/components/Layout";
-import { LoadingScreen } from "src/components/Loading/LoadingScreen";
-import { Task } from "src/components/Tasks/Task";
-import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI";
+import { TaskPage } from "src/components/TaskPage/TaskPage";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
import { TaskType } from "src/types/Task";
-const RandomTask = () => {
- const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random);
+const Random = () => ;
- if (isLoading) {
- return ;
- }
+Random.getLayout = getDashboardLayout;
- if (tasks.length === 0) {
- return ;
- }
-
- return (
- <>
-
- Random Task
-
-
-
- >
- );
-};
-
-RandomTask.getLayout = (page) => getDashboardLayout(page);
-
-export default RandomTask;
+export default Random;
diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts
index 6f51a870..8b258a25 100644
--- a/website/src/types/Conversation.ts
+++ b/website/src/types/Conversation.ts
@@ -1,7 +1,22 @@
-export interface Message {
+export type EmojiOp = "add" | "remove" | "toggle";
+
+export interface MessageEmoji {
+ name: string;
+ count: number;
+}
+
+export interface MessageEmojis {
+ emojis: { [emoji: string]: number };
+ user_emojis: string[];
+}
+
+export interface Message extends MessageEmojis {
+ id: string;
text: string;
is_assistant: boolean;
- id: string;
+ lang: string;
+ created_date: string; // iso date string
+ parent_id: string;
frontend_message_id?: string;
}
diff --git a/website/src/types/Hooks.ts b/website/src/types/Hooks.ts
new file mode 100644
index 00000000..8fd9aa4f
--- /dev/null
+++ b/website/src/types/Hooks.ts
@@ -0,0 +1,26 @@
+import { MutatorCallback, MutatorOptions } from "swr";
+
+import { BaseTask, TaskResponse, TaskType } from "./Task";
+
+type ConcreteTaskResponse = TaskResponse;
+type TaskError = { errorCode: number; message: string };
+
+type Trigger = (
+ extraArgument?: unknown,
+ options?: MutatorOptions
+) => Promise;
+
+type Reset = (
+ data?: ConcreteTaskResponse | Promise | MutatorCallback,
+ opts?: boolean | MutatorOptions
+) => Promise;
+
+type TaskAPIHook = {
+ tasks: TaskResponse[];
+ isLoading: boolean;
+ error: TaskError;
+ trigger: Trigger;
+ reset: Reset;
+};
+
+export type TaskApiHooks = Record TaskAPIHook>;
diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts
index 12e37db0..7ae48138 100644
--- a/website/src/types/Task.ts
+++ b/website/src/types/Task.ts
@@ -1,4 +1,4 @@
-export const enum TaskType {
+export enum TaskType {
initial_prompt = "initial_prompt",
assistant_reply = "assistant_reply",
prompter_reply = "prompter_reply",
diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts
index a791916e..5fbc84c7 100644
--- a/website/src/types/Tasks.ts
+++ b/website/src/types/Tasks.ts
@@ -1,4 +1,4 @@
-import { Conversation } from "./Conversation";
+import { Conversation, Message } from "./Conversation";
import { BaseTask, TaskType } from "./Task";
export interface CreateInitialPromptTask extends BaseTask {
@@ -33,29 +33,39 @@ export interface RankPrompterRepliesTask extends BaseTask {
replies: string[];
}
-export interface LabelAssistantReplyTask extends BaseTask {
+export interface Label {
+ display_text: string;
+ help_text: string;
+ name: string;
+ widget: "flag" | "yes_no" | "likert";
+}
+
+export interface BaseLabelTask extends BaseTask {
+ message_id: string;
+ labels: Label[];
+ valid_labels: string[];
+ disposition: "spam" | "quality";
+ mode: "simple" | "full";
+ mandatory_labels?: string[];
+}
+
+export interface LabelAssistantReplyTask extends BaseLabelTask {
type: TaskType.label_assistant_reply;
- message_id: string;
conversation: Conversation;
+ reply_message: Message;
reply: string;
- valid_labels: string[];
- mode: "simple" | "full";
- mandatory_labels?: string[];
}
-export interface LabelPrompterReplyTask extends BaseTask {
+export interface LabelPrompterReplyTask extends BaseLabelTask {
type: TaskType.label_prompter_reply;
- message_id: string;
conversation: Conversation;
+ reply_message: Message;
reply: string;
- valid_labels: string[];
- mode: "simple" | "full";
- mandatory_labels?: string[];
}
-export interface LabelInitialPromptTask extends BaseTask {
+export interface LabelInitialPromptTask extends BaseLabelTask {
type: TaskType.label_initial_prompt;
- message_id: string;
- valid_labels: string[];
prompt: string;
}
+
+export type LabelTaskType = LabelAssistantReplyTask | LabelPrompterReplyTask | LabelInitialPromptTask;
diff --git a/website/styles/Theme/colors.tsx b/website/styles/Theme/colors.tsx
index 7f82ebce..26609185 100644
--- a/website/styles/Theme/colors.tsx
+++ b/website/styles/Theme/colors.tsx
@@ -5,6 +5,7 @@ export const colors = {
div: "white",
text: "black",
highlight: "blue.400",
+ active: "blue.400",
},
dark: {
bg: "gray.900",
@@ -12,5 +13,6 @@ export const colors = {
div: "gray.700",
text: "gray.200",
highlight: "blue.500",
+ active: "blue.500",
},
};
diff --git a/website/types/i18next.d.ts b/website/types/i18next.d.ts
index 873101fd..0a2cf10a 100644
--- a/website/types/i18next.d.ts
+++ b/website/types/i18next.d.ts
@@ -1,9 +1,9 @@
-import "i18next";
-
import type common from "public/locales/en/common.json";
import type dashboard from "public/locales/en/dashboard.json";
import type index from "public/locales/en/index.json";
import type leaderboard from "public/locales/en/leaderboard.json";
+import type message from "public/locales/en/message.json";
+import type labelling from "public/locales/en/labelling.json";
import type tasks from "public/locales/en/tasks.json";
declare module "i18next" {
@@ -14,6 +14,8 @@ declare module "i18next" {
index: typeof index;
leaderboard: typeof leaderboard;
tasks: typeof tasks;
+ message: typeof message;
+ labelling: typeof labelling;
};
}
}