mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Label Initial Prompt
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
"rules": {
|
||||
"unused-imports/no-unused-imports": "warn",
|
||||
"simple-import-sort/imports": "warn",
|
||||
"simple-import-sort/exports": "warn"
|
||||
"simple-import-sort/exports": "warn",
|
||||
"eqeqeq": "warn"
|
||||
},
|
||||
"plugins": ["simple-import-sort", "unused-imports"]
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import Link from "next/link";
|
||||
|
||||
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
|
||||
|
||||
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate];
|
||||
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
|
||||
|
||||
export const TaskOption = () => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
@@ -12,9 +12,9 @@ export const TaskOption = () => {
|
||||
<Box className="flex flex-col gap-14" fontFamily="inter">
|
||||
{displayTaskCategories.map((category, categoryIndex) => (
|
||||
<div key={categoryIndex}>
|
||||
<Text className="text-2xl font-bold pb-4">{TaskCategory[category]}</Text>
|
||||
<Text className="text-2xl font-bold pb-4">{category}</Text>
|
||||
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
|
||||
{TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => (
|
||||
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
|
||||
<Link key={itemIndex} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
|
||||
@@ -118,7 +118,8 @@ export const FlaggableElement = (props) => {
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
function FlagCheckbox(props: {
|
||||
|
||||
export function FlagCheckbox(props: {
|
||||
option: textFlagLabels;
|
||||
idx: number;
|
||||
checkboxValues: boolean[];
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Grid } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
@@ -8,29 +9,30 @@ export interface Message {
|
||||
is_assistant: boolean;
|
||||
}
|
||||
|
||||
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
|
||||
if (colorMode === "light") {
|
||||
return isAssistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return isAssistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
};
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { text } = messageProps;
|
||||
|
||||
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} key={i + text}>
|
||||
<div
|
||||
key={i + text}
|
||||
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
|
||||
>
|
||||
{text}
|
||||
</div>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
});
|
||||
// Maybe also show a legend of the colors?
|
||||
return <Grid gap={2}>{items}</Grid>;
|
||||
};
|
||||
|
||||
export const MessageView = ({ is_assistant, text }: Message) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const bgColor = useMemo(() => {
|
||||
if (colorMode === "light") {
|
||||
return is_assistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return is_assistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
}, [colorMode, is_assistant]);
|
||||
|
||||
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
|
||||
};
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
export enum TaskCategory {
|
||||
Create,
|
||||
Evaluate,
|
||||
Create = "Create",
|
||||
Evaluate = "Evaluate",
|
||||
Label = "Label",
|
||||
}
|
||||
|
||||
export const TaskTypes = [
|
||||
// create
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
@@ -31,6 +33,7 @@ export const TaskTypes = [
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the assistant`s reply",
|
||||
},
|
||||
// evaluate
|
||||
{
|
||||
label: "Rank User Replies",
|
||||
category: TaskCategory.Evaluate,
|
||||
@@ -52,4 +55,12 @@ export const TaskTypes = [
|
||||
pathname: "/evaluate/rank_initial_prompts",
|
||||
type: "rank_initial_prompts",
|
||||
},
|
||||
// label
|
||||
{
|
||||
label: "Label Initial Prompt",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
type: "label_initial_prompt",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
message_id: string;
|
||||
prompt: string;
|
||||
type: string;
|
||||
valid_labels: string[];
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelingTask = <LabelingTaskType>({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => {
|
||||
type ConcreteTaskResponse = TaskResponse<LabelingTaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<Array<ConcreteTaskResponse>>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, {
|
||||
onSuccess: (data: ConcreteTaskResponse) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length === 0 && !isLoading && !error) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks, isLoading, mutate, error]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (reply) => {
|
||||
const newTask: ConcreteTaskResponse = await reply.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, labels: Record<string, string>) =>
|
||||
trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
|
||||
return { tasks, isLoading, submit, error, reset: mutate };
|
||||
};
|
||||
@@ -42,7 +42,7 @@ export class OasstApiClient {
|
||||
} catch (e) {
|
||||
throw new OasstError(errorText, 0, resp.status);
|
||||
}
|
||||
throw new OasstError(error.message, error.error_code, resp.status);
|
||||
throw new OasstError(error.message ?? error, error.error_code, resp.status);
|
||||
}
|
||||
|
||||
return await resp.json();
|
||||
|
||||
@@ -35,7 +35,12 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
|
||||
let newTask;
|
||||
try {
|
||||
newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
|
||||
} catch (err) {
|
||||
return res.status(500).json(err);
|
||||
}
|
||||
|
||||
// Stores the new task with our database.
|
||||
const newRegisteredTask = await prisma.registeredTask.create({
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useEffect, useId, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelingTask<LabelInitialPromptTask>({
|
||||
taskApiEndpoint: "label_initial_prompt",
|
||||
});
|
||||
|
||||
const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => {
|
||||
const labels = task.valid_labels.reduce((obj, label, i) => {
|
||||
obj[label] = sliderValues[i].toString();
|
||||
return obj;
|
||||
}, {} as Record<string, string>);
|
||||
|
||||
submit(id, task.message_id, task.prompt, labels);
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
|
||||
<p className="text-lg py-1">Provide labels for the following prompt</p>
|
||||
<MessageView text={task.prompt} is_assistant />
|
||||
</>
|
||||
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
|
||||
</TwoColumnsWithCards>
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={reset} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelInitialPrompt;
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
|
||||
interface CheckboxSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user