Label Initial Prompt

This commit is contained in:
AbdBarho
2023-01-07 11:25:25 +01:00
parent 4ecdc57621
commit 53814d77ab
9 changed files with 210 additions and 25 deletions
+2 -1
View File
@@ -8,7 +8,8 @@
"rules": {
"unused-imports/no-unused-imports": "warn",
"simple-import-sort/imports": "warn",
"simple-import-sort/exports": "warn"
"simple-import-sort/exports": "warn",
"eqeqeq": "warn"
},
"plugins": ["simple-import-sort", "unused-imports"]
}
@@ -3,7 +3,7 @@ import Link from "next/link";
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate];
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
export const TaskOption = () => {
const backgroundColor = useColorModeValue("white", "gray.700");
@@ -12,9 +12,9 @@ export const TaskOption = () => {
<Box className="flex flex-col gap-14" fontFamily="inter">
{displayTaskCategories.map((category, categoryIndex) => (
<div key={categoryIndex}>
<Text className="text-2xl font-bold pb-4">{TaskCategory[category]}</Text>
<Text className="text-2xl font-bold pb-4">{category}</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => (
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
+2 -1
View File
@@ -118,7 +118,8 @@ export const FlaggableElement = (props) => {
</Popover>
);
};
function FlagCheckbox(props: {
export function FlagCheckbox(props: {
option: textFlagLabels;
idx: number;
checkboxValues: boolean[];
+18 -16
View File
@@ -1,5 +1,6 @@
import { Grid } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { FlaggableElement } from "./FlaggableElement";
@@ -8,29 +9,30 @@ export interface Message {
is_assistant: boolean;
}
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
if (colorMode === "light") {
return isAssistant ? "bg-slate-800" : "bg-sky-900";
} else {
return isAssistant ? "bg-black" : "bg-sky-900";
}
};
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const { colorMode } = useColorMode();
const items = messages.map((messageProps: Message, i: number) => {
const { text } = messageProps;
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
return (
<FlaggableElement text={text} post_id={post_id} key={i + text}>
<div
key={i + text}
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
>
{text}
</div>
<MessageView {...messageProps} />
</FlaggableElement>
);
});
// Maybe also show a legend of the colors?
return <Grid gap={2}>{items}</Grid>;
};
export const MessageView = ({ is_assistant, text }: Message) => {
const { colorMode } = useColorMode();
const bgColor = useMemo(() => {
if (colorMode === "light") {
return is_assistant ? "bg-slate-800" : "bg-sky-900";
} else {
return is_assistant ? "bg-black" : "bg-sky-900";
}
}, [colorMode, is_assistant]);
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
};
+13 -2
View File
@@ -1,9 +1,11 @@
export enum TaskCategory {
Create,
Evaluate,
Create = "Create",
Evaluate = "Evaluate",
Label = "Label",
}
export const TaskTypes = [
// create
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
@@ -31,6 +33,7 @@ export const TaskTypes = [
overview: "Given the following conversation, provide an adequate reply",
instruction: "Provide the assistant`s reply",
},
// evaluate
{
label: "Rank User Replies",
category: TaskCategory.Evaluate,
@@ -52,4 +55,12 @@ export const TaskTypes = [
pathname: "/evaluate/rank_initial_prompts",
type: "rank_initial_prompts",
},
// label
{
label: "Label Initial Prompt",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_initial_prompt",
type: "label_initial_prompt",
},
];
+52
View File
@@ -0,0 +1,52 @@
import { useEffect, useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
}
export interface LabelInitialPromptTask {
id: string;
message_id: string;
prompt: string;
type: string;
valid_labels: string[];
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
onSuccess: (data: ConcreteTaskResponse) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length === 0 && !isLoading && !error) {
mutate();
}
}, [tasks, isLoading, mutate, error]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (reply) => {
const newTask: ConcreteTaskResponse = await reply.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
return { tasks, isLoading, submit, error, reset: mutate };
};
+1 -1
View File
@@ -42,7 +42,7 @@ export class OasstApiClient {
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
throw new OasstError(error.message, error.error_code, resp.status);
throw new OasstError(error.message ?? error, error.error_code, resp.status);
}
return await resp.json();
+6 -1
View File
@@ -35,7 +35,12 @@ const handler = async (req, res) => {
},
});
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
let newTask;
try {
newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
} catch (err) {
return res.status(500).json(err);
}
// Stores the new task with our database.
const newRegisteredTask = await prisma.registeredTask.create({
@@ -0,0 +1,113 @@
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useEffect, useId, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
import { colors } from "styles/Theme/colors";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
taskApiEndpoint: "label_initial_prompt",
});
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
const labels = task.valid_labels.reduce((obj, label, i) => {
obj[label] = sliderValues[i].toString();
return obj;
}, {} as Record<string, string>);
submit(id, task.message_id, task.prompt, labels);
};
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
if (isLoading) {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
const task = tasks[0].task;
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
<>
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
<p className="text-lg py-1">Provide labels for the following prompt</p>
<MessageView text={task.prompt} is_assistant />
</>
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
</div>
);
};
export default LabelInitialPrompt;
// TODO: consolidate with FlaggableElement
interface CheckboxSliderGroupProps {
labelIDs: Array<string>;
onChange: (sliderValues: number[]) => unknown;
}
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
useEffect(() => {
onChange(sliderValues);
}, [sliderValues, onChange]);
return (
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
{labelIDs.map((labelId, idx) => (
<CheckboxSliderItem
key={idx}
labelId={labelId}
sliderValue={sliderValues[idx]}
sliderHandler={(sliderValue) => {
const newState = sliderValues.slice();
newState[idx] = sliderValue;
setSliderValues(newState);
}}
/>
))}
</Grid>
);
};
function CheckboxSliderItem(props: {
labelId: string;
sliderValue: number;
sliderHandler: (newVal: number) => unknown;
}) {
const id = useId();
const { colorMode } = useColorMode();
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
return (
<>
<label className="text-sm" htmlFor={id}>
{/* TODO: display real text instead of just the id */}
<span className={labelTextClass}>{props.labelId}</span>
</label>
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
<SliderTrack>
<SliderFilledTrack />
<SliderThumb />
</SliderTrack>
</Slider>
</>
);
}