From 53814d77abc06a639bb7d8c7ea9ffecb6c58bc86 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sat, 7 Jan 2023 11:25:25 +0100 Subject: [PATCH] Label Initial Prompt --- website/.eslintrc.json | 3 +- .../src/components/Dashboard/TaskOption.tsx | 6 +- website/src/components/FlaggableElement.tsx | 3 +- website/src/components/Messages.tsx | 34 +++--- website/src/components/Tasks/TaskTypes.tsx | 15 ++- website/src/hooks/useLabelingTask.ts | 52 ++++++++ website/src/lib/oasst_api_client.ts | 2 +- website/src/pages/api/update_task.ts | 7 +- .../src/pages/label/label_initial_prompt.tsx | 113 ++++++++++++++++++ 9 files changed, 210 insertions(+), 25 deletions(-) create mode 100644 website/src/hooks/useLabelingTask.ts create mode 100644 website/src/pages/label/label_initial_prompt.tsx diff --git a/website/.eslintrc.json b/website/.eslintrc.json index 04b5d542..690c055c 100644 --- a/website/.eslintrc.json +++ b/website/.eslintrc.json @@ -8,7 +8,8 @@ "rules": { "unused-imports/no-unused-imports": "warn", "simple-import-sort/imports": "warn", - "simple-import-sort/exports": "warn" + "simple-import-sort/exports": "warn", + "eqeqeq": "warn" }, "plugins": ["simple-import-sort", "unused-imports"] } diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx index 5e6ceb2f..1c070e17 100644 --- a/website/src/components/Dashboard/TaskOption.tsx +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -3,7 +3,7 @@ import Link from "next/link"; import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes"; -const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate]; +const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label]; export const TaskOption = () => { const backgroundColor = useColorModeValue("white", "gray.700"); @@ -12,9 +12,9 @@ export const TaskOption = () => { {displayTaskCategories.map((category, categoryIndex) => (
- {TaskCategory[category]} + {category} - {TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => ( + {TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => ( { ); }; -function FlagCheckbox(props: { + +export function FlagCheckbox(props: { option: textFlagLabels; idx: number; checkboxValues: boolean[]; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index d3d7b3b8..7b69bc50 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,5 +1,6 @@ import { Grid } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; +import { useMemo } from "react"; import { FlaggableElement } from "./FlaggableElement"; @@ -8,29 +9,30 @@ export interface Message { is_assistant: boolean; } -const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => { - if (colorMode === "light") { - return isAssistant ? "bg-slate-800" : "bg-sky-900"; - } else { - return isAssistant ? "bg-black" : "bg-sky-900"; - } -}; - export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { - const { colorMode } = useColorMode(); + const items = messages.map((messageProps: Message, i: number) => { + const { text } = messageProps; - const items = messages.map(({ text, is_assistant }: Message, i: number) => { return ( -
- {text} -
+
); }); // Maybe also show a legend of the colors? return {items}; }; + +export const MessageView = ({ is_assistant, text }: Message) => { + const { colorMode } = useColorMode(); + + const bgColor = useMemo(() => { + if (colorMode === "light") { + return is_assistant ? "bg-slate-800" : "bg-sky-900"; + } else { + return is_assistant ? "bg-black" : "bg-sky-900"; + } + }, [colorMode, is_assistant]); + + return
{text}
; +}; diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 413a1e16..7cec2177 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -1,9 +1,11 @@ export enum TaskCategory { - Create, - Evaluate, + Create = "Create", + Evaluate = "Evaluate", + Label = "Label", } export const TaskTypes = [ + // create { label: "Create Initial Prompts", desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.", @@ -31,6 +33,7 @@ export const TaskTypes = [ overview: "Given the following conversation, provide an adequate reply", instruction: "Provide the assistant`s reply", }, + // evaluate { label: "Rank User Replies", category: TaskCategory.Evaluate, @@ -52,4 +55,12 @@ export const TaskTypes = [ pathname: "/evaluate/rank_initial_prompts", type: "rank_initial_prompts", }, + // label + { + label: "Label Initial Prompt", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_initial_prompt", + type: "label_initial_prompt", + }, ]; diff --git a/website/src/hooks/useLabelingTask.ts b/website/src/hooks/useLabelingTask.ts new file mode 100644 index 00000000..872909b7 --- /dev/null +++ b/website/src/hooks/useLabelingTask.ts @@ -0,0 +1,52 @@ +import { useEffect, useState } from "react"; +import fetcher from "src/lib/fetcher"; +import poster from "src/lib/poster"; +import useSWRImmutable from "swr/immutable"; +import useSWRMutation from "swr/mutation"; + +// TODO: type & centralize types for all tasks +interface TaskResponse { + id: string; + userId: string; + task: TaskType; +} + +export interface LabelInitialPromptTask { + id: string; + message_id: string; + prompt: string; + type: string; + valid_labels: string[]; +} + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelingTask = ({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => { + type ConcreteTaskResponse = TaskResponse; + + const [tasks, setTasks] = useState>([]); + + const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, { + onSuccess: (data: ConcreteTaskResponse) => { + setTasks([data]); + }, + }); + + useEffect(() => { + if (tasks.length === 0 && !isLoading && !error) { + mutate(); + } + }, [tasks, isLoading, mutate, error]); + + const { trigger } = useSWRMutation("/api/update_task", poster, { + onSuccess: async (reply) => { + const newTask: ConcreteTaskResponse = await reply.json(); + setTasks((oldTasks) => [...oldTasks, newTask]); + }, + }); + + const submit = (id: string, message_id: string, text: string, labels: Record) => + trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + + return { tasks, isLoading, submit, error, reset: mutate }; +}; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 4cf891e1..86854c21 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -42,7 +42,7 @@ export class OasstApiClient { } catch (e) { throw new OasstError(errorText, 0, resp.status); } - throw new OasstError(error.message, error.error_code, resp.status); + throw new OasstError(error.message ?? error, error.error_code, resp.status); } return await resp.json(); diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 4eea8c1e..c8760324 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -35,7 +35,12 @@ const handler = async (req, res) => { }, }); - const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token); + let newTask; + try { + newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token); + } catch (err) { + return res.status(500).json(err); + } // Stores the new task with our database. const newRegisteredTask = await prisma.registeredTask.create({ diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx new file mode 100644 index 00000000..66ab0580 --- /dev/null +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -0,0 +1,113 @@ +import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; +import { useColorMode } from "@chakra-ui/react"; +import { useEffect, useId, useState } from "react"; +import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { MessageView } from "src/components/Messages"; +import { TaskControls } from "src/components/Survey/TaskControls"; +import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; +import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask"; +import { colors } from "styles/Theme/colors"; + +const LabelInitialPrompt = () => { + const [sliderValues, setSliderValues] = useState([]); + + const { tasks, isLoading, submit, reset } = useLabelingTask({ + taskApiEndpoint: "label_initial_prompt", + }); + + const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => { + const labels = task.valid_labels.reduce((obj, label, i) => { + obj[label] = sliderValues[i].toString(); + return obj; + }, {} as Record); + + submit(id, task.message_id, task.prompt, labels); + }; + + const { colorMode } = useColorMode(); + const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; + + if (isLoading) { + return ; + } + + if (tasks.length === 0) { + return No tasks found...; + } + + const task = tasks[0].task; + + return ( +
+ + <> +
Label Initial Prompt
+

Provide labels for the following prompt

+ + + +
+ +
+ ); +}; + +export default LabelInitialPrompt; + +// TODO: consolidate with FlaggableElement + +interface CheckboxSliderGroupProps { + labelIDs: Array; + onChange: (sliderValues: number[]) => unknown; +} + +const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => { + const [sliderValues, setSliderValues] = useState(Array.from({ length: labelIDs.length }).map(() => 0)); + + useEffect(() => { + onChange(sliderValues); + }, [sliderValues, onChange]); + + return ( + + {labelIDs.map((labelId, idx) => ( + { + const newState = sliderValues.slice(); + newState[idx] = sliderValue; + setSliderValues(newState); + }} + /> + ))} + + ); +}; + +function CheckboxSliderItem(props: { + labelId: string; + sliderValue: number; + sliderHandler: (newVal: number) => unknown; +}) { + const id = useId(); + const { colorMode } = useColorMode(); + + const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`; + + return ( + <> + + props.sliderHandler(val / 100)}> + + + + + + + ); +}