Adding a new web api path that returns valid labels and then fetching from that within FlaggableElement. This allows FlaggableElement to fetch all its own data and remove the need to pipe labels through a series of components

This commit is contained in:
Keith Stevens
2023-01-10 20:57:42 +09:00
parent 0968d8add6
commit 747c3501d1
14 changed files with 63 additions and 65 deletions
+25 -14
View File
@@ -22,9 +22,11 @@ import {
useId,
} from "@chakra-ui/react";
import { FlagIcon, QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
import { useState } from "react";
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import { colors } from "styles/Theme/colors";
import useSWR from "swr";
import useSWRMutation from "swr/mutation";
interface textFlagLabels {
@@ -34,16 +36,27 @@ interface textFlagLabels {
}
export const FlaggableElement = (props) => {
const [labels, setLabels] = useState([]);
const [checkboxValues, setCheckboxValues] = useState([]);
const [sliderValues, setSliderValues] = useState([]);
const [isEditing, setIsEditing] = useBoolean();
const flaggable_labels = props.flaggable_labels;
const TEXT_LABEL_FLAGS =
flaggable_labels?.valid_labels?.map((valid_label) => {
return {
attributeName: valid_label.name,
labelText: valid_label.display_text,
additionalExplanation: valid_label.help_text,
};
}) || [];
const { data, isLoading } = useSWR("/api/valid_labels", fetcher);
useEffect(() => {
if (isLoading) {
return;
}
const { valid_labels } = data;
const newLabels = valid_labels.map((valid_label) => ({
attributeName: valid_label.name,
labelText: valid_label.display_text,
additionalExplanation: valid_label.help_text,
}));
setSliderValues(new Array(newLabels.length).fill(1));
setCheckboxValues(new Array(newLabels.length).fill(false));
setLabels(newLabels);
}, [data, isLoading]);
const { trigger } = useSWRMutation("/api/set_label", poster, {
onSuccess: () => {
setIsEditing.off();
@@ -52,7 +65,7 @@ export const FlaggableElement = (props) => {
const submitResponse = () => {
const label_map: Map<string, number> = new Map();
TEXT_LABEL_FLAGS.forEach((flag, i) => {
labels.forEach((flag, i) => {
if (checkboxValues[i]) {
label_map.set(flag.attributeName, sliderValues[i]);
}
@@ -64,8 +77,6 @@ export const FlaggableElement = (props) => {
text: props.text,
});
};
const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false));
const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1));
const handleCheckboxState = (isChecked, idx) => {
setCheckboxValues(
@@ -110,7 +121,7 @@ export const FlaggableElement = (props) => {
<PopoverCloseButton />
</div>
<PopoverBody>
{TEXT_LABEL_FLAGS.map((option, i) => (
{labels.map((option, i) => (
<FlagCheckbox
option={option}
key={i}
+2 -17
View File
@@ -2,29 +2,14 @@ import { Grid } from "@chakra-ui/react";
import { forwardRef, useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { Message } from "src/types/Conversation";
import { ValidLabel } from "src/types/Task";
import { FlaggableElement } from "./FlaggableElement";
export const Messages = ({
messages,
post_id,
valid_labels,
}: {
messages: Message[];
post_id: string;
valid_labels: ValidLabel[];
}) => {
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id, text } = messageProps;
return (
<FlaggableElement
text={text}
post_id={post_id}
message_id={message_id}
key={i + text}
flaggable_labels={valid_labels}
>
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<MessageView {...messageProps} />
</FlaggableElement>
);
@@ -1,11 +1,11 @@
import { Stack, StackDivider } from "@chakra-ui/react";
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
export function MessageTable({ messages, valid_labels }) {
export function MessageTable({ messages }) {
return (
<Stack divider={<StackDivider />} spacing="4">
{messages.map((item, idx) => (
<MessageTableEntry item={item} idx={idx} key={item.message_id || item.id} valid_labels={valid_labels} />
<MessageTableEntry item={item} idx={idx} key={item.message_id || item.id} />
))}
</Stack>
);
@@ -2,7 +2,6 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { boolean } from "boolean";
import NextLink from "next/link";
import { FlaggableElement } from "src/components/FlaggableElement";
import type { ValidLabel } from "src/types/Task";
interface Message {
text: string;
@@ -12,14 +11,13 @@ interface Message {
interface MessageTableEntryProps {
item: Message;
idx: number;
valid_labels: ValidLabel[];
}
export function MessageTableEntry(props: MessageTableEntryProps) {
const { item, idx, valid_labels } = props;
const { item, idx } = props;
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
return (
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`} flaggable_labels={valid_labels}>
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
<HStack>
<Avatar
name={`${boolean(item.is_assistant) ? "Assitant" : "User"}`}
@@ -64,7 +64,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
<Flex justifyContent="center" pb="2">
<Box maxWidth="container.sm" flex="1" px={isFirstOrOnly ? [4, 6, 8, 9] : "0"}>
<Box px={isFirstOrOnly ? "2" : "0"}>
<MessageTableEntry item={message} idx={1} valid_labels={[]} />
<MessageTableEntry item={message} idx={1} />
</Box>
</Box>
</Flex>
@@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
<HStack {...MessageStackProps}>
{children.map((item, idx) => (
<Box maxWidth="container.sm" flex="1" key={`recursiveMessageWChildren_${idx}`}>
<MessageTableEntry item={item} idx={idx * 2} valid_labels={[]} />
<MessageTableEntry item={item} idx={idx * 2} />
</Box>
))}
</HStack>
+1 -4
View File
@@ -17,7 +17,6 @@ export interface CreateTaskProps {
}
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
const task = tasks[0].task;
const valid_labels = tasks[0].valid_labels;
const [inputText, setInputText] = useState("");
const submitResponse = (task: { id: string }) => {
@@ -41,9 +40,7 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m
<>
<h5 className="text-lg font-semibold">{taskType.label}</h5>
<p className="text-lg py-1">{taskType.overview}</p>
{task.conversation ? (
<Messages messages={task.conversation.messages} post_id={task.id} valid_labels={valid_labels} />
) : null}
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
</>
<>
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
@@ -33,7 +33,6 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
messages = messages.map((message, index) => ({ ...message, id: index }));
}
const valid_labels = tasks[0].valid_labels;
const sortables = tasks[0].task.replies ? "replies" : "prompts";
return (
@@ -43,7 +42,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
<p className="text-lg py-1">
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
</p>
{messages ? <MessageTable messages={messages} valid_labels={valid_labels} /> : null}
{messages ? <MessageTable messages={messages} /> : null}
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
</SurveyCard>
@@ -23,7 +23,6 @@ const handler = async (req, res) => {
// Fetch the new task.
const task = await oasstApiClient.fetchTask(task_type, token);
const valid_labels = await oasstApiClient.fetch_valid_text();
// Store the task and link it to the user..
const registeredTask = await prisma.registeredTask.create({
@@ -37,9 +36,6 @@ const handler = async (req, res) => {
},
});
// Add the valid labels that can be used to flag messages in this Task
registeredTask["valid_labels"] = valid_labels;
// Send the results to the client.
res.status(200).json(registeredTask);
};
+23
View File
@@ -0,0 +1,23 @@
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
/**
* TODO
*/
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
// Fetch the new task.
const valid_labels = await oasstApiClient.fetch_valid_text();
// Send the results to the client.
res.status(200).json(valid_labels);
};
export default handler;
@@ -16,7 +16,6 @@ const LabelAssistantReply = () => {
}
const task = tasks[0].task;
const valid_labels = tasks[0].valid_labels;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: true, message_id: task.message_id },
@@ -26,7 +25,7 @@ const LabelAssistantReply = () => {
<LabelTask
title="Label Assistant Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} valid_labels={valid_labels} />}
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
@@ -16,7 +16,6 @@ const LabelPrompterReply = () => {
}
const task = tasks[0].task;
const valid_labels = tasks[0].valid_labels;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: false, message_id: task.message_id },
@@ -26,7 +25,7 @@ const LabelPrompterReply = () => {
<LabelTask
title="Label Prompter Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} valid_labels={valid_labels} />}
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
+1 -1
View File
@@ -41,7 +41,7 @@ const MessageDetail = ({ id }) => {
Parent
</Text>
<Box rounded="lg" p="2">
<MessageTableEntry item={parent} idx={1} valid_labels={[]} />
<MessageTableEntry item={parent} idx={1} />
</Box>
</>
)}
+2 -10
View File
@@ -52,11 +52,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{receivedMessages ? (
<MessageTable messages={messages} valid_labels={[]} />
) : (
<CircularProgress isIndeterminate />
)}
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
</Box>
</Box>
<Box>
@@ -70,11 +66,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{receivedUserMessages ? (
<MessageTable messages={userMessages} valid_labels={[]} />
) : (
<CircularProgress isIndeterminate />
)}
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
</Box>
</Box>
</SimpleGrid>
-1
View File
@@ -27,5 +27,4 @@ export interface TaskResponse<Task extends BaseTask> {
id: string;
userId: string;
task: Task;
valid_labels: ValidLabel[];
}