diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index c4fe5c58..63370444 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -29,7 +29,7 @@ import useSWRMutation from "swr/mutation"; export const FlaggableElement = (props) => { const [isEditing, setIsEditing] = useBoolean(); - const { trigger } = useSWRMutation("/api/v1/text_labels", poster, { + const { trigger } = useSWRMutation("/api/set_label", poster, { onSuccess: () => { setIsEditing.off; }, @@ -42,7 +42,12 @@ export const FlaggableElement = (props) => { label_map.set(flag.attributeName, sliderValues[i]); } }); - trigger({ post_id: props.post_id, label_map: Object.fromEntries(label_map), text: props.text }); + trigger({ + message_id: props.message_id, + post_id: props.post_id, + label_map: Object.fromEntries(label_map), + text: props.text, + }); }; const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false)); const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1)); @@ -184,40 +189,40 @@ interface textFlagLabels { const TEXT_LABEL_FLAGS: textFlagLabels[] = [ // For the time being this list is configured on the FE. // In the future it may be provided by the API. + // { + // attributeName: "fails_task", + // labelText: "Fails to follow the correct instruction / task", + // additionalExplanation: "__TODO__", + // }, + // { + // attributeName: "not_customer_assistant_appropriate", + // labelText: "Inappropriate for customer assistant", + // additionalExplanation: "__TODO__", + // }, { - attributeName: "fails_task", - labelText: "Fails to follow the correct instruction / task", - additionalExplanation: "__TODO__", - }, - { - attributeName: "not_customer_assistant_appropriate", - labelText: "Inappropriate for customer assistant", - additionalExplanation: "__TODO__", - }, - { - attributeName: "contains_sexual_content", + attributeName: "sexual_content", labelText: "Contains sexual content", }, { - attributeName: "contains_violent_content", + attributeName: "violence", labelText: "Contains violent content", }, - { - attributeName: "encourages_violence", - labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm", - }, - { - attributeName: "denigrates_a_protected_class", - labelText: "Denigrates a protected class", - }, - { - attributeName: "gives_harmful_advice", - labelText: "Fails to follow the correct instruction / task", - additionalExplanation: - "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.", - }, - { - attributeName: "expresses_moral_judgement", - labelText: "Expresses moral judgement", - }, + // { + // attributeName: "encourages_violence", + // labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm", + // }, + // { + // attributeName: "denigrates_a_protected_class", + // labelText: "Denigrates a protected class", + // }, + // { + // attributeName: "gives_harmful_advice", + // labelText: "Fails to follow the correct instruction / task", + // additionalExplanation: + // "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.", + // }, + // { + // attributeName: "expresses_moral_judgement", + // labelText: "Expresses moral judgement", + // }, ]; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 7b69bc50..226c6154 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -7,14 +7,15 @@ import { FlaggableElement } from "./FlaggableElement"; export interface Message { text: string; is_assistant: boolean; + message_id: string; } export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { const items = messages.map((messageProps: Message, i: number) => { + const { message_id } = messageProps; const { text } = messageProps; - return ( - + ); @@ -23,7 +24,7 @@ export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: return {items}; }; -export const MessageView = ({ is_assistant, text }: Message) => { +export const MessageView = ({ is_assistant, text, message_id }: Message) => { const { colorMode } = useColorMode(); const bgColor = useMemo(() => { diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts new file mode 100644 index 00000000..4db5ddaf --- /dev/null +++ b/website/src/pages/api/set_label.ts @@ -0,0 +1,41 @@ +import { getToken } from "next-auth/jwt"; +import prisma from "src/lib/prismadb"; + +/** + * Sets the Label in the Backend. + * + */ +const handler = async (req, res) => { + const token = await getToken({ req }); + + // Return nothing if the user isn't registered. + if (!token) { + res.status(401).end(); + return; + } + + // Parse out the local message_id, task ID and the interaction contents. + const { message_id, post_id, label_map, text } = await JSON.parse(req.body); + + const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, { + method: "POST", + headers: { + "X-API-Key": process.env.FASTAPI_KEY, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + type: "text_labels", + message_id: message_id, + labels: label_map, + text: text, + user: { + id: token.sub, + display_name: token.name || token.email, + auth_method: "local", + }, + }), + }); + res.status(interactionRes.status).end(); +}; + +export default handler; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 0c3b47be..e400e8fd 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -43,7 +43,7 @@ const LabelInitialPrompt = () => { <>
Label Initial Prompt

Provide labels for the following prompt

- +