Merge pull request #523 from LAION-AI/label_prompter_reply

Add Label Prompter Reply Task
This commit is contained in:
Keith Stevens
2023-01-08 19:45:53 +09:00
committed by GitHub
11 changed files with 272 additions and 153 deletions
@@ -1,7 +1,7 @@
import { Progress } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
export const LoadingScreen = ({ text }) => {
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
const { colorMode } = useColorMode();
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
+1 -2
View File
@@ -12,8 +12,7 @@ export interface Message {
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id } = messageProps;
const { text } = messageProps;
const { message_id, text } = messageProps;
return (
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<MessageView {...messageProps} />
+100
View File
@@ -0,0 +1,100 @@
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { colors } from "styles/Theme/colors";
export const LabelTask = ({
title,
desc,
messages,
inputs,
controls,
}: {
title: string;
desc: string;
messages: ReactNode;
inputs: ReactNode;
controls: ReactNode;
}) => {
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
const card = useMemo(
() => (
<>
<h5 className="text-lg font-semibold">{title}</h5>
<p className="text-lg py-1">{desc}</p>
{messages}
</>
),
[title, desc, messages]
);
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
{card}
{inputs}
</TwoColumnsWithCards>
{controls}
</div>
);
};
// TODO: consolidate with FlaggableElement
interface LabelSliderGroupProps {
labelIDs: Array<string>;
onChange: (sliderValues: number[]) => unknown;
}
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
useEffect(() => {
onChange(sliderValues);
}, [sliderValues, onChange]);
return (
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
{labelIDs.map((labelId, idx) => (
<CheckboxSliderItem
key={idx}
labelId={labelId}
sliderValue={sliderValues[idx]}
sliderHandler={(sliderValue) => {
const newState = sliderValues.slice();
newState[idx] = sliderValue;
setSliderValues(newState);
}}
/>
))}
</Grid>
);
};
function CheckboxSliderItem(props: {
labelId: string;
sliderValue: number;
sliderHandler: (newVal: number) => unknown;
}) {
const id = useId();
const { colorMode } = useColorMode();
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
return (
<>
<label className="text-sm" htmlFor={id}>
{/* TODO: display real text instead of just the id */}
<span className={labelTextClass}>{props.labelId}</span>
</label>
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
<SliderTrack>
<SliderFilledTrack />
<SliderThumb />
</SliderTrack>
</Slider>
</>
);
}
@@ -63,4 +63,11 @@ export const TaskTypes = [
pathname: "/label/label_initial_prompt",
type: "label_initial_prompt",
},
{
label: "Label Prompter Reply",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_prompter_reply",
type: "label_prompter_reply",
},
];
@@ -0,0 +1,42 @@
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
export interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
}
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
type ConcreteTaskResponse = TaskResponse<TaskType>;
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>(
"/api/new_task/" + taskApiEndpoint,
fetcher,
{
onSuccess: (data) => setTasks([data]),
}
);
useEffect(() => {
if (tasks.length === 0 && !isLoading && !error) {
mutate();
}
}, [tasks, isLoading, mutate, error]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (response) => {
const newTask: ConcreteTaskResponse = await response.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
return { tasks, isLoading, trigger, error, reset: mutate };
};
@@ -0,0 +1,24 @@
import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI";
export interface LabelInitialPromptTask {
id: string;
type: "label_initial_prompt";
message_id: string;
valid_labels: string[];
prompt: string;
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<LabelInitialPromptTask>("label_initial_prompt");
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
@@ -0,0 +1,31 @@
import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI";
export interface LabelPrompterReplyTask {
id: string;
type: "label_prompter_reply";
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<LabelPrompterReplyTask>("label_prompter_reply");
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
-52
View File
@@ -1,52 +0,0 @@
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
}
export interface LabelInitialPromptTask {
id: string;
message_id: string;
prompt: string;
type: string;
valid_labels: string[];
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
onSuccess: (data: ConcreteTaskResponse) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length === 0 && !isLoading && !error) {
mutate();
}
}, [tasks, isLoading, mutate, error]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (reply) => {
const newTask: ConcreteTaskResponse = await reply.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
return { tasks, isLoading, submit, error, reset: mutate };
};
+2 -2
View File
@@ -1,8 +1,8 @@
export { default } from "next-auth/middleware";
/**
* Guards all pages under `/grading` and redirects them to the sign in page.
* Guards these pages and redirects them to the sign in page.
*/
export const config = {
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
};
@@ -1,113 +1,38 @@
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useEffect, useId, useState } from "react";
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
import { colors } from "styles/Theme/colors";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelInitialPrompt";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
taskApiEndpoint: "label_initial_prompt",
});
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
const labels = task.valid_labels.reduce((obj, label, i) => {
obj[label] = sliderValues[i].toString();
return obj;
}, {} as Record<string, string>);
submit(id, task.message_id, task.prompt, labels);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
<p className="text-lg py-1">Provide labels for the following prompt</p>
<MessageView text={task.prompt} is_assistant message_id={task.message_id} />
</>
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
</div>
<LabelTask
title="Label Initial Prompt"
desc="Provide labels for the following prompt"
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkip={reset}
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelInitialPrompt;
// TODO: consolidate with FlaggableElement
interface CheckboxSliderGroupProps {
labelIDs: Array<string>;
onChange: (sliderValues: number[]) => unknown;
}
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
useEffect(() => {
onChange(sliderValues);
}, [sliderValues, onChange]);
return (
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
{labelIDs.map((labelId, idx) => (
<CheckboxSliderItem
key={idx}
labelId={labelId}
sliderValue={sliderValues[idx]}
sliderHandler={(sliderValue) => {
const newState = sliderValues.slice();
newState[idx] = sliderValue;
setSliderValues(newState);
}}
/>
))}
</Grid>
);
};
function CheckboxSliderItem(props: {
labelId: string;
sliderValue: number;
sliderHandler: (newVal: number) => unknown;
}) {
const id = useId();
const { colorMode } = useColorMode();
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
return (
<>
<label className="text-sm" htmlFor={id}>
{/* TODO: display real text instead of just the id */}
<span className={labelTextClass}>{props.labelId}</span>
</label>
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
<SliderTrack>
<SliderFilledTrack />
<SliderThumb />
</SliderTrack>
</Slider>
</>
);
}
@@ -0,0 +1,43 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelPrompterReply";
const LabelPrompterReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: false, message_id: task.message_id },
];
return (
<LabelTask
title="Label Prompter Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkip={reset}
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelPrompterReply;