diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index a7157c4a..d82b5750 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -22,9 +22,11 @@ import { useId, } from "@chakra-ui/react"; import { FlagIcon, QuestionMarkCircleIcon } from "@heroicons/react/20/solid"; -import { useState } from "react"; +import { useEffect, useState } from "react"; +import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import { colors } from "styles/Theme/colors"; +import useSWR from "swr"; import useSWRMutation from "swr/mutation"; interface textFlagLabels { @@ -34,16 +36,27 @@ interface textFlagLabels { } export const FlaggableElement = (props) => { + const [labels, setLabels] = useState([]); + const [checkboxValues, setCheckboxValues] = useState([]); + const [sliderValues, setSliderValues] = useState([]); const [isEditing, setIsEditing] = useBoolean(); - const flaggable_labels = props.flaggable_labels; - const TEXT_LABEL_FLAGS = - flaggable_labels?.valid_labels?.map((valid_label) => { - return { - attributeName: valid_label.name, - labelText: valid_label.display_text, - additionalExplanation: valid_label.help_text, - }; - }) || []; + + const { data, isLoading } = useSWR("/api/valid_labels", fetcher); + useEffect(() => { + if (isLoading) { + return; + } + const { valid_labels } = data; + const newLabels = valid_labels.map((valid_label) => ({ + attributeName: valid_label.name, + labelText: valid_label.display_text, + additionalExplanation: valid_label.help_text, + })); + setSliderValues(new Array(newLabels.length).fill(1)); + setCheckboxValues(new Array(newLabels.length).fill(false)); + setLabels(newLabels); + }, [data, isLoading]); + const { trigger } = useSWRMutation("/api/set_label", poster, { onSuccess: () => { setIsEditing.off(); @@ -52,7 +65,7 @@ export const FlaggableElement = (props) => { const submitResponse = () => { const label_map: Map = new Map(); - TEXT_LABEL_FLAGS.forEach((flag, i) => { + labels.forEach((flag, i) => { if (checkboxValues[i]) { label_map.set(flag.attributeName, sliderValues[i]); } @@ -64,8 +77,6 @@ export const FlaggableElement = (props) => { text: props.text, }); }; - const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false)); - const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1)); const handleCheckboxState = (isChecked, idx) => { setCheckboxValues( @@ -110,7 +121,7 @@ export const FlaggableElement = (props) => { - {TEXT_LABEL_FLAGS.map((option, i) => ( + {labels.map((option, i) => ( { +export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { const items = messages.map((messageProps: Message, i: number) => { const { message_id, text } = messageProps; return ( - + ); diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index 872b79f1..7797c6a3 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -1,11 +1,11 @@ import { Stack, StackDivider } from "@chakra-ui/react"; import { MessageTableEntry } from "src/components/Messages/MessageTableEntry"; -export function MessageTable({ messages, valid_labels }) { +export function MessageTable({ messages }) { return ( } spacing="4"> {messages.map((item, idx) => ( - + ))} ); diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index e9e8775a..9fad7262 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -2,7 +2,6 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; import NextLink from "next/link"; import { FlaggableElement } from "src/components/FlaggableElement"; -import type { ValidLabel } from "src/types/Task"; interface Message { text: string; @@ -12,14 +11,13 @@ interface Message { interface MessageTableEntryProps { item: Message; idx: number; - valid_labels: ValidLabel[]; } export function MessageTableEntry(props: MessageTableEntryProps) { - const { item, idx, valid_labels } = props; + const { item, idx } = props; const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900"); return ( - + - + @@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { {children.map((item, idx) => ( - + ))} diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index a424315a..4d8daea0 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -17,7 +17,6 @@ export interface CreateTaskProps { } export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => { const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; const [inputText, setInputText] = useState(""); const submitResponse = (task: { id: string }) => { @@ -41,9 +40,7 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m <>
{taskType.label}

{taskType.overview}

- {task.conversation ? ( - - ) : null} + {task.conversation ? : null} <>
{taskType.instruction}
diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx index 61ed3889..d0a1f404 100644 --- a/website/src/components/Tasks/EvaluateTask.tsx +++ b/website/src/components/Tasks/EvaluateTask.tsx @@ -33,7 +33,6 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla messages = messages.map((message, index) => ({ ...message, id: index })); } - const valid_labels = tasks[0].valid_labels; const sortables = tasks[0].task.replies ? "replies" : "prompts"; return ( @@ -43,7 +42,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla

Given the following {sortables}, sort them from best to worst, best being first, worst being last.

- {messages ? : null} + {messages ? : null} diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 80334f76..9f3be55c 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -23,7 +23,6 @@ const handler = async (req, res) => { // Fetch the new task. const task = await oasstApiClient.fetchTask(task_type, token); - const valid_labels = await oasstApiClient.fetch_valid_text(); // Store the task and link it to the user.. const registeredTask = await prisma.registeredTask.create({ @@ -37,9 +36,6 @@ const handler = async (req, res) => { }, }); - // Add the valid labels that can be used to flag messages in this Task - registeredTask["valid_labels"] = valid_labels; - // Send the results to the client. res.status(200).json(registeredTask); }; diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts new file mode 100644 index 00000000..ab788c05 --- /dev/null +++ b/website/src/pages/api/valid_labels.ts @@ -0,0 +1,23 @@ +import { getToken } from "next-auth/jwt"; +import { oasstApiClient } from "src/lib/oasst_api_client"; + +/** + * TODO + */ +const handler = async (req, res) => { + const token = await getToken({ req }); + + // Return nothing if the user isn't registered. + if (!token) { + res.status(401).end(); + return; + } + + // Fetch the new task. + const valid_labels = await oasstApiClient.fetch_valid_text(); + + // Send the results to the client. + res.status(200).json(valid_labels); +}; + +export default handler; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 99c10f56..59a7bbcc 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -16,7 +16,6 @@ const LabelAssistantReply = () => { } const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; const messages: Message[] = [ ...task.conversation.messages, { text: task.reply, is_assistant: true, message_id: task.message_id }, @@ -26,7 +25,7 @@ const LabelAssistantReply = () => { } + messages={} inputs={} controls={ { } const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; const messages: Message[] = [ ...task.conversation.messages, { text: task.reply, is_assistant: false, message_id: task.message_id }, @@ -26,7 +25,7 @@ const LabelPrompterReply = () => { } + messages={} inputs={} controls={ { Parent - + )} diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index ed48d47b..9cdc2ac5 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -52,11 +52,7 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {receivedMessages ? ( - - ) : ( - - )} + {receivedMessages ? : } @@ -70,11 +66,7 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {receivedUserMessages ? ( - - ) : ( - - )} + {receivedUserMessages ? : } diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index 6975fa14..7f614ab6 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -27,5 +27,4 @@ export interface TaskResponse { id: string; userId: string; task: Task; - valid_labels: ValidLabel[]; }