mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge pull request #523 from LAION-AI/label_prompter_reply
Add Label Prompter Reply Task
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import { Progress } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
export const LoadingScreen = ({ text }) => {
|
||||
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ export interface Message {
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { message_id } = messageProps;
|
||||
const { text } = messageProps;
|
||||
const { message_id, text } = messageProps;
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
|
||||
<MessageView {...messageProps} />
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export const LabelTask = ({
|
||||
title,
|
||||
desc,
|
||||
messages,
|
||||
inputs,
|
||||
controls,
|
||||
}: {
|
||||
title: string;
|
||||
desc: string;
|
||||
messages: ReactNode;
|
||||
inputs: ReactNode;
|
||||
controls: ReactNode;
|
||||
}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
const card = useMemo(
|
||||
() => (
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{title}</h5>
|
||||
<p className="text-lg py-1">{desc}</p>
|
||||
{messages}
|
||||
</>
|
||||
),
|
||||
[title, desc, messages]
|
||||
);
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
{card}
|
||||
{inputs}
|
||||
</TwoColumnsWithCards>
|
||||
{controls}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
interface LabelSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -63,4 +63,11 @@ export const TaskTypes = [
|
||||
pathname: "/label/label_initial_prompt",
|
||||
type: "label_initial_prompt",
|
||||
},
|
||||
{
|
||||
label: "Label Prompter Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
type: "label_prompter_reply",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
|
||||
export interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
|
||||
type ConcreteTaskResponse = TaskResponse<TaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>(
|
||||
"/api/new_task/" + taskApiEndpoint,
|
||||
fetcher,
|
||||
{
|
||||
onSuccess: (data) => setTasks([data]),
|
||||
}
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length === 0 && !isLoading && !error) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks, isLoading, mutate, error]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (response) => {
|
||||
const newTask: ConcreteTaskResponse = await response.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
return { tasks, isLoading, trigger, error, reset: mutate };
|
||||
};
|
||||
@@ -0,0 +1,24 @@
|
||||
import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI";
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
type: "label_initial_prompt";
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelInitialPromptTask = () => {
|
||||
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<LabelInitialPromptTask>("label_initial_prompt");
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
|
||||
console.assert(validLabels.length === labelWeights.length);
|
||||
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
|
||||
|
||||
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
};
|
||||
|
||||
return { tasks, isLoading, submit, reset, error };
|
||||
};
|
||||
@@ -0,0 +1,31 @@
|
||||
import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI";
|
||||
|
||||
export interface LabelPrompterReplyTask {
|
||||
id: string;
|
||||
type: "label_prompter_reply";
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
|
||||
|
||||
export const useLabelPrompterReplyTask = () => {
|
||||
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<LabelPrompterReplyTask>("label_prompter_reply");
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
|
||||
console.assert(validLabels.length === labelWeights.length);
|
||||
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
|
||||
|
||||
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
};
|
||||
|
||||
return { tasks, isLoading, submit, reset, error };
|
||||
};
|
||||
@@ -1,52 +0,0 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
message_id: string;
|
||||
prompt: string;
|
||||
type: string;
|
||||
valid_labels: string[];
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
|
||||
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
|
||||
onSuccess: (data: ConcreteTaskResponse) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length === 0 && !isLoading && !error) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks, isLoading, mutate, error]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (reply) => {
|
||||
const newTask: ConcreteTaskResponse = await reply.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
|
||||
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
|
||||
return { tasks, isLoading, submit, error, reset: mutate };
|
||||
};
|
||||
@@ -1,8 +1,8 @@
|
||||
export { default } from "next-auth/middleware";
|
||||
|
||||
/**
|
||||
* Guards all pages under `/grading` and redirects them to the sign in page.
|
||||
* Guards these pages and redirects them to the sign in page.
|
||||
*/
|
||||
export const config = {
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
|
||||
};
|
||||
|
||||
@@ -1,113 +1,38 @@
|
||||
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useEffect, useId, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelInitialPrompt";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
|
||||
taskApiEndpoint: "label_initial_prompt",
|
||||
});
|
||||
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
|
||||
|
||||
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
|
||||
const labels = task.valid_labels.reduce((obj, label, i) => {
|
||||
obj[label] = sliderValues[i].toString();
|
||||
return obj;
|
||||
}, {} as Record<string, string>);
|
||||
|
||||
submit(id, task.message_id, task.prompt, labels);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
|
||||
<p className="text-lg py-1">Provide labels for the following prompt</p>
|
||||
<MessageView text={task.prompt} is_assistant message_id={task.message_id} />
|
||||
</>
|
||||
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
|
||||
</TwoColumnsWithCards>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
|
||||
</div>
|
||||
<LabelTask
|
||||
title="Label Initial Prompt"
|
||||
desc="Provide labels for the following prompt"
|
||||
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
|
||||
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelInitialPrompt;
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
|
||||
interface CheckboxSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelPrompterReply";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: false, message_id: task.message_id },
|
||||
];
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Prompter Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelPrompterReply;
|
||||
Reference in New Issue
Block a user