Merge pull request #518 from melvinebenezer/371_set_labels

#371 set labels
This commit is contained in:
Keith Stevens
2023-01-08 15:06:47 +09:00
committed by GitHub
4 changed files with 83 additions and 36 deletions
+37 -32
View File
@@ -29,7 +29,7 @@ import useSWRMutation from "swr/mutation";
export const FlaggableElement = (props) => {
const [isEditing, setIsEditing] = useBoolean();
const { trigger } = useSWRMutation("/api/v1/text_labels", poster, {
const { trigger } = useSWRMutation("/api/set_label", poster, {
onSuccess: () => {
setIsEditing.off;
},
@@ -42,7 +42,12 @@ export const FlaggableElement = (props) => {
label_map.set(flag.attributeName, sliderValues[i]);
}
});
trigger({ post_id: props.post_id, label_map: Object.fromEntries(label_map), text: props.text });
trigger({
message_id: props.message_id,
post_id: props.post_id,
label_map: Object.fromEntries(label_map),
text: props.text,
});
};
const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false));
const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1));
@@ -184,40 +189,40 @@ interface textFlagLabels {
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
// For the time being this list is configured on the FE.
// In the future it may be provided by the API.
// {
// attributeName: "fails_task",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation: "__TODO__",
// },
// {
// attributeName: "not_customer_assistant_appropriate",
// labelText: "Inappropriate for customer assistant",
// additionalExplanation: "__TODO__",
// },
{
attributeName: "fails_task",
labelText: "Fails to follow the correct instruction / task",
additionalExplanation: "__TODO__",
},
{
attributeName: "not_customer_assistant_appropriate",
labelText: "Inappropriate for customer assistant",
additionalExplanation: "__TODO__",
},
{
attributeName: "contains_sexual_content",
attributeName: "sexual_content",
labelText: "Contains sexual content",
},
{
attributeName: "contains_violent_content",
attributeName: "violence",
labelText: "Contains violent content",
},
{
attributeName: "encourages_violence",
labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
},
{
attributeName: "denigrates_a_protected_class",
labelText: "Denigrates a protected class",
},
{
attributeName: "gives_harmful_advice",
labelText: "Fails to follow the correct instruction / task",
additionalExplanation:
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
},
{
attributeName: "expresses_moral_judgement",
labelText: "Expresses moral judgement",
},
// {
// attributeName: "encourages_violence",
// labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
// },
// {
// attributeName: "denigrates_a_protected_class",
// labelText: "Denigrates a protected class",
// },
// {
// attributeName: "gives_harmful_advice",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation:
// "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
// },
// {
// attributeName: "expresses_moral_judgement",
// labelText: "Expresses moral judgement",
// },
];
+4 -3
View File
@@ -7,14 +7,15 @@ import { FlaggableElement } from "./FlaggableElement";
export interface Message {
text: string;
is_assistant: boolean;
message_id: string;
}
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id } = messageProps;
const { text } = messageProps;
return (
<FlaggableElement text={text} post_id={post_id} key={i + text}>
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<MessageView {...messageProps} />
</FlaggableElement>
);
@@ -23,7 +24,7 @@ export const Messages = ({ messages, post_id }: { messages: Message[]; post_id:
return <Grid gap={2}>{items}</Grid>;
};
export const MessageView = ({ is_assistant, text }: Message) => {
export const MessageView = ({ is_assistant, text, message_id }: Message) => {
const { colorMode } = useColorMode();
const bgColor = useMemo(() => {
+41
View File
@@ -0,0 +1,41 @@
import { getToken } from "next-auth/jwt";
import prisma from "src/lib/prismadb";
/**
* Sets the Label in the Backend.
*
*/
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
// Parse out the local message_id, task ID and the interaction contents.
const { message_id, post_id, label_map, text } = await JSON.parse(req.body);
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
type: "text_labels",
message_id: message_id,
labels: label_map,
text: text,
user: {
id: token.sub,
display_name: token.name || token.email,
auth_method: "local",
},
}),
});
res.status(interactionRes.status).end();
};
export default handler;
@@ -43,7 +43,7 @@ const LabelInitialPrompt = () => {
<>
<h5 className="text-lg font-semibold">Label Initial Prompt</h5>
<p className="text-lg py-1">Provide labels for the following prompt</p>
<MessageView text={task.prompt} is_assistant />
<MessageView text={task.prompt} is_assistant message_id={task.message_id} />
</>
<CheckboxSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />
</TwoColumnsWithCards>