mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Adding a new web api path that returns valid labels and then fetching from that within FlaggableElement. This allows FlaggableElement to fetch all its own data and remove the need to pipe labels through a series of components
This commit is contained in:
@@ -22,9 +22,11 @@ import {
|
||||
useId,
|
||||
} from "@chakra-ui/react";
|
||||
import { FlagIcon, QuestionMarkCircleIcon } from "@heroicons/react/20/solid";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import useSWR from "swr";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
interface textFlagLabels {
|
||||
@@ -34,16 +36,27 @@ interface textFlagLabels {
|
||||
}
|
||||
|
||||
export const FlaggableElement = (props) => {
|
||||
const [labels, setLabels] = useState([]);
|
||||
const [checkboxValues, setCheckboxValues] = useState([]);
|
||||
const [sliderValues, setSliderValues] = useState([]);
|
||||
const [isEditing, setIsEditing] = useBoolean();
|
||||
const flaggable_labels = props.flaggable_labels;
|
||||
const TEXT_LABEL_FLAGS =
|
||||
flaggable_labels?.valid_labels?.map((valid_label) => {
|
||||
return {
|
||||
attributeName: valid_label.name,
|
||||
labelText: valid_label.display_text,
|
||||
additionalExplanation: valid_label.help_text,
|
||||
};
|
||||
}) || [];
|
||||
|
||||
const { data, isLoading } = useSWR("/api/valid_labels", fetcher);
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
return;
|
||||
}
|
||||
const { valid_labels } = data;
|
||||
const newLabels = valid_labels.map((valid_label) => ({
|
||||
attributeName: valid_label.name,
|
||||
labelText: valid_label.display_text,
|
||||
additionalExplanation: valid_label.help_text,
|
||||
}));
|
||||
setSliderValues(new Array(newLabels.length).fill(1));
|
||||
setCheckboxValues(new Array(newLabels.length).fill(false));
|
||||
setLabels(newLabels);
|
||||
}, [data, isLoading]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/set_label", poster, {
|
||||
onSuccess: () => {
|
||||
setIsEditing.off();
|
||||
@@ -52,7 +65,7 @@ export const FlaggableElement = (props) => {
|
||||
|
||||
const submitResponse = () => {
|
||||
const label_map: Map<string, number> = new Map();
|
||||
TEXT_LABEL_FLAGS.forEach((flag, i) => {
|
||||
labels.forEach((flag, i) => {
|
||||
if (checkboxValues[i]) {
|
||||
label_map.set(flag.attributeName, sliderValues[i]);
|
||||
}
|
||||
@@ -64,8 +77,6 @@ export const FlaggableElement = (props) => {
|
||||
text: props.text,
|
||||
});
|
||||
};
|
||||
const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false));
|
||||
const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1));
|
||||
|
||||
const handleCheckboxState = (isChecked, idx) => {
|
||||
setCheckboxValues(
|
||||
@@ -110,7 +121,7 @@ export const FlaggableElement = (props) => {
|
||||
<PopoverCloseButton />
|
||||
</div>
|
||||
<PopoverBody>
|
||||
{TEXT_LABEL_FLAGS.map((option, i) => (
|
||||
{labels.map((option, i) => (
|
||||
<FlagCheckbox
|
||||
option={option}
|
||||
key={i}
|
||||
|
||||
@@ -2,29 +2,14 @@ import { Grid } from "@chakra-ui/react";
|
||||
import { forwardRef, useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import { ValidLabel } from "src/types/Task";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
export const Messages = ({
|
||||
messages,
|
||||
post_id,
|
||||
valid_labels,
|
||||
}: {
|
||||
messages: Message[];
|
||||
post_id: string;
|
||||
valid_labels: ValidLabel[];
|
||||
}) => {
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { message_id, text } = messageProps;
|
||||
return (
|
||||
<FlaggableElement
|
||||
text={text}
|
||||
post_id={post_id}
|
||||
message_id={message_id}
|
||||
key={i + text}
|
||||
flaggable_labels={valid_labels}
|
||||
>
|
||||
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Stack, StackDivider } from "@chakra-ui/react";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
|
||||
export function MessageTable({ messages, valid_labels }) {
|
||||
export function MessageTable({ messages }) {
|
||||
return (
|
||||
<Stack divider={<StackDivider />} spacing="4">
|
||||
{messages.map((item, idx) => (
|
||||
<MessageTableEntry item={item} idx={idx} key={item.message_id || item.id} valid_labels={valid_labels} />
|
||||
<MessageTableEntry item={item} idx={idx} key={item.message_id || item.id} />
|
||||
))}
|
||||
</Stack>
|
||||
);
|
||||
|
||||
@@ -2,7 +2,6 @@ import { Avatar, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import NextLink from "next/link";
|
||||
import { FlaggableElement } from "src/components/FlaggableElement";
|
||||
import type { ValidLabel } from "src/types/Task";
|
||||
|
||||
interface Message {
|
||||
text: string;
|
||||
@@ -12,14 +11,13 @@ interface Message {
|
||||
interface MessageTableEntryProps {
|
||||
item: Message;
|
||||
idx: number;
|
||||
valid_labels: ValidLabel[];
|
||||
}
|
||||
export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
const { item, idx, valid_labels } = props;
|
||||
const { item, idx } = props;
|
||||
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
|
||||
|
||||
return (
|
||||
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`} flaggable_labels={valid_labels}>
|
||||
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
|
||||
<HStack>
|
||||
<Avatar
|
||||
name={`${boolean(item.is_assistant) ? "Assitant" : "User"}`}
|
||||
|
||||
@@ -64,7 +64,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
<Flex justifyContent="center" pb="2">
|
||||
<Box maxWidth="container.sm" flex="1" px={isFirstOrOnly ? [4, 6, 8, 9] : "0"}>
|
||||
<Box px={isFirstOrOnly ? "2" : "0"}>
|
||||
<MessageTableEntry item={message} idx={1} valid_labels={[]} />
|
||||
<MessageTableEntry item={message} idx={1} />
|
||||
</Box>
|
||||
</Box>
|
||||
</Flex>
|
||||
@@ -90,7 +90,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
<HStack {...MessageStackProps}>
|
||||
{children.map((item, idx) => (
|
||||
<Box maxWidth="container.sm" flex="1" key={`recursiveMessageWChildren_${idx}`}>
|
||||
<MessageTableEntry item={item} idx={idx * 2} valid_labels={[]} />
|
||||
<MessageTableEntry item={item} idx={idx * 2} />
|
||||
</Box>
|
||||
))}
|
||||
</HStack>
|
||||
|
||||
@@ -17,7 +17,6 @@ export interface CreateTaskProps {
|
||||
}
|
||||
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
|
||||
const task = tasks[0].task;
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
@@ -41,9 +40,7 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
{task.conversation ? (
|
||||
<Messages messages={task.conversation.messages} post_id={task.id} valid_labels={valid_labels} />
|
||||
) : null}
|
||||
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
|
||||
|
||||
@@ -33,7 +33,6 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
|
||||
messages = messages.map((message, index) => ({ ...message, id: index }));
|
||||
}
|
||||
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
const sortables = tasks[0].task.replies ? "replies" : "prompts";
|
||||
|
||||
return (
|
||||
@@ -43,7 +42,7 @@ export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgCla
|
||||
<p className="text-lg py-1">
|
||||
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
{messages ? <MessageTable messages={messages} valid_labels={valid_labels} /> : null}
|
||||
{messages ? <MessageTable messages={messages} /> : null}
|
||||
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Fetch the new task.
|
||||
const task = await oasstApiClient.fetchTask(task_type, token);
|
||||
const valid_labels = await oasstApiClient.fetch_valid_text();
|
||||
|
||||
// Store the task and link it to the user..
|
||||
const registeredTask = await prisma.registeredTask.create({
|
||||
@@ -37,9 +36,6 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
// Add the valid labels that can be used to flag messages in this Task
|
||||
registeredTask["valid_labels"] = valid_labels;
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(registeredTask);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
|
||||
/**
|
||||
* TODO
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
// Fetch the new task.
|
||||
const valid_labels = await oasstApiClient.fetch_valid_text();
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(valid_labels);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -16,7 +16,6 @@ const LabelAssistantReply = () => {
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: true, message_id: task.message_id },
|
||||
@@ -26,7 +25,7 @@ const LabelAssistantReply = () => {
|
||||
<LabelTask
|
||||
title="Label Assistant Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} valid_labels={valid_labels} />}
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
|
||||
@@ -16,7 +16,6 @@ const LabelPrompterReply = () => {
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const valid_labels = tasks[0].valid_labels;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: false, message_id: task.message_id },
|
||||
@@ -26,7 +25,7 @@ const LabelPrompterReply = () => {
|
||||
<LabelTask
|
||||
title="Label Prompter Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} valid_labels={valid_labels} />}
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
|
||||
@@ -41,7 +41,7 @@ const MessageDetail = ({ id }) => {
|
||||
Parent
|
||||
</Text>
|
||||
<Box rounded="lg" p="2">
|
||||
<MessageTableEntry item={parent} idx={1} valid_labels={[]} />
|
||||
<MessageTableEntry item={parent} idx={1} />
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -52,11 +52,7 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{receivedMessages ? (
|
||||
<MessageTable messages={messages} valid_labels={[]} />
|
||||
) : (
|
||||
<CircularProgress isIndeterminate />
|
||||
)}
|
||||
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
@@ -70,11 +66,7 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{receivedUserMessages ? (
|
||||
<MessageTable messages={userMessages} valid_labels={[]} />
|
||||
) : (
|
||||
<CircularProgress isIndeterminate />
|
||||
)}
|
||||
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
|
||||
@@ -27,5 +27,4 @@ export interface TaskResponse<Task extends BaseTask> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: Task;
|
||||
valid_labels: ValidLabel[];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user