From 747c3501d104e9c930d3952eaee8fb216edfd6d1 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 10 Jan 2023 20:57:42 +0900 Subject: [PATCH] Adding a new web api path that returns valid labels and then fetching from that within FlaggableElement. This allows FlaggableElement to fetch all its own data and remove the need to pipe labels through a series of components --- website/src/components/FlaggableElement.tsx | 39 ++++++++++++------- website/src/components/Messages.tsx | 19 +-------- .../src/components/Messages/MessageTable.tsx | 4 +- .../components/Messages/MessageTableEntry.tsx | 6 +-- .../Messages/MessageWithChildren.tsx | 4 +- website/src/components/Tasks/CreateTask.tsx | 5 +-- website/src/components/Tasks/EvaluateTask.tsx | 3 +- website/src/pages/api/new_task/[task_type].ts | 4 -- website/src/pages/api/valid_labels.ts | 23 +++++++++++ .../src/pages/label/label_assistant_reply.tsx | 3 +- .../src/pages/label/label_prompter_reply.tsx | 3 +- website/src/pages/messages/[id]/index.tsx | 2 +- website/src/pages/messages/index.tsx | 12 +----- website/src/types/Task.ts | 1 - 14 files changed, 63 insertions(+), 65 deletions(-) create mode 100644 website/src/pages/api/valid_labels.ts diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index a7157c4a..d82b5750 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -22,9 +22,11 @@ import { useId, } from "@chakra-ui/react"; import { FlagIcon, QuestionMarkCircleIcon } from "@heroicons/react/20/solid"; -import { useState } from "react"; +import { useEffect, useState } from "react"; +import fetcher from "src/lib/fetcher"; import poster from "src/lib/poster"; import { colors } from "styles/Theme/colors"; +import useSWR from "swr"; import useSWRMutation from "swr/mutation"; interface textFlagLabels { @@ -34,16 +36,27 @@ interface textFlagLabels { } export const FlaggableElement = (props) => { + const [labels, setLabels] = useState([]); + const [checkboxValues, setCheckboxValues] = useState([]); + const [sliderValues, setSliderValues] = useState([]); const [isEditing, setIsEditing] = useBoolean(); - const flaggable_labels = props.flaggable_labels; - const TEXT_LABEL_FLAGS = - flaggable_labels?.valid_labels?.map((valid_label) => { - return { - attributeName: valid_label.name, - labelText: valid_label.display_text, - additionalExplanation: valid_label.help_text, - }; - }) || []; + + const { data, isLoading } = useSWR("/api/valid_labels", fetcher); + useEffect(() => { + if (isLoading) { + return; + } + const { valid_labels } = data; + const newLabels = valid_labels.map((valid_label) => ({ + attributeName: valid_label.name, + labelText: valid_label.display_text, + additionalExplanation: valid_label.help_text, + })); + setSliderValues(new Array(newLabels.length).fill(1)); + setCheckboxValues(new Array(newLabels.length).fill(false)); + setLabels(newLabels); + }, [data, isLoading]); + const { trigger } = useSWRMutation("/api/set_label", poster, { onSuccess: () => { setIsEditing.off(); @@ -52,7 +65,7 @@ export const FlaggableElement = (props) => { const submitResponse = () => { const label_map: Map = new Map(); - TEXT_LABEL_FLAGS.forEach((flag, i) => { + labels.forEach((flag, i) => { if (checkboxValues[i]) { label_map.set(flag.attributeName, sliderValues[i]); } @@ -64,8 +77,6 @@ export const FlaggableElement = (props) => { text: props.text, }); }; - const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false)); - const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1)); const handleCheckboxState = (isChecked, idx) => { setCheckboxValues( @@ -110,7 +121,7 @@ export const FlaggableElement = (props) => { - {TEXT_LABEL_FLAGS.map((option, i) => ( + {labels.map((option, i) => ( { +export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { const items = messages.map((messageProps: Message, i: number) => { const { message_id, text } = messageProps; return ( - + ); diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index 872b79f1..7797c6a3 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -1,11 +1,11 @@ import { Stack, StackDivider } from "@chakra-ui/react"; import { MessageTableEntry } from "src/components/Messages/MessageTableEntry"; -export function MessageTable({ messages, valid_labels }) { +export function MessageTable({ messages }) { return ( } spacing="4"> {messages.map((item, idx) => ( - + ))} ); diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index e9e8775a..9fad7262 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -2,7 +2,6 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; import NextLink from "next/link"; import { FlaggableElement } from "src/components/FlaggableElement"; -import type { ValidLabel } from "src/types/Task"; interface Message { text: string; @@ -12,14 +11,13 @@ interface Message { interface MessageTableEntryProps { item: Message; idx: number; - valid_labels: ValidLabel[]; } export function MessageTableEntry(props: MessageTableEntryProps) { - const { item, idx, valid_labels } = props; + const { item, idx } = props; const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900"); return ( - + - + @@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { {children.map((item, idx) => ( - + ))} diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index a424315a..4d8daea0 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -17,7 +17,6 @@ export interface CreateTaskProps { } export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => { const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; const [inputText, setInputText] = useState(""); const submitResponse = (task: { id: string }) => { @@ -41,9 +40,7 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m <>
{taskType.label}

{taskType.overview}

- {task.conversation ? ( - - ) : null} + {task.conversation ? : null} <>
{taskType.instruction}
diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx index 61ed3889..d0a1f404 100644 --- a/website/src/components/Tasks/EvaluateTask.tsx +++ b/website/src/components/Tasks/EvaluateTask.tsx @@ -33,7 +33,6 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla messages = messages.map((message, index) => ({ ...message, id: index })); } - const valid_labels = tasks[0].valid_labels; const sortables = tasks[0].task.replies ? "replies" : "prompts"; return ( @@ -43,7 +42,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla

Given the following {sortables}, sort them from best to worst, best being first, worst being last.

- {messages ? : null} + {messages ? : null} diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index 80334f76..9f3be55c 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -23,7 +23,6 @@ const handler = async (req, res) => { // Fetch the new task. const task = await oasstApiClient.fetchTask(task_type, token); - const valid_labels = await oasstApiClient.fetch_valid_text(); // Store the task and link it to the user.. const registeredTask = await prisma.registeredTask.create({ @@ -37,9 +36,6 @@ const handler = async (req, res) => { }, }); - // Add the valid labels that can be used to flag messages in this Task - registeredTask["valid_labels"] = valid_labels; - // Send the results to the client. res.status(200).json(registeredTask); }; diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts new file mode 100644 index 00000000..ab788c05 --- /dev/null +++ b/website/src/pages/api/valid_labels.ts @@ -0,0 +1,23 @@ +import { getToken } from "next-auth/jwt"; +import { oasstApiClient } from "src/lib/oasst_api_client"; + +/** + * TODO + */ +const handler = async (req, res) => { + const token = await getToken({ req }); + + // Return nothing if the user isn't registered. + if (!token) { + res.status(401).end(); + return; + } + + // Fetch the new task. + const valid_labels = await oasstApiClient.fetch_valid_text(); + + // Send the results to the client. + res.status(200).json(valid_labels); +}; + +export default handler; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 99c10f56..59a7bbcc 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -16,7 +16,6 @@ const LabelAssistantReply = () => { } const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; const messages: Message[] = [ ...task.conversation.messages, { text: task.reply, is_assistant: true, message_id: task.message_id }, @@ -26,7 +25,7 @@ const LabelAssistantReply = () => { } + messages={} inputs={} controls={ { } const task = tasks[0].task; - const valid_labels = tasks[0].valid_labels; const messages: Message[] = [ ...task.conversation.messages, { text: task.reply, is_assistant: false, message_id: task.message_id }, @@ -26,7 +25,7 @@ const LabelPrompterReply = () => { } + messages={} inputs={} controls={ { Parent - + )} diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index ed48d47b..9cdc2ac5 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -52,11 +52,7 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {receivedMessages ? ( - - ) : ( - - )} + {receivedMessages ? : } @@ -70,11 +66,7 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {receivedUserMessages ? ( - - ) : ( - - )} + {receivedUserMessages ? : } diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index 6975fa14..7f614ab6 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -27,5 +27,4 @@ export interface TaskResponse { id: string; userId: string; task: Task; - valid_labels: ValidLabel[]; }