Merge pull request #530 from LAION-AI/label_assistant_reply

Add Label Assistant Reply  Task
This commit is contained in:
Keith Stevens
2023-01-08 20:26:27 +09:00
committed by GitHub
9 changed files with 127 additions and 44 deletions
@@ -70,4 +70,11 @@ export const TaskTypes = [
pathname: "/label/label_prompter_reply",
type: "label_prompter_reply",
},
{
label: "Label Assistant Reply",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_assistant_reply",
type: "label_assistant_reply",
},
];
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelAssistantReplyTask {
id: string;
type: LabelingTaskType.label_assistant_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
export const useLabelAssistantReplyTask = () =>
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
@@ -0,0 +1,15 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelInitialPromptTask {
id: string;
type: LabelingTaskType.label_initial_prompt;
message_id: string;
valid_labels: string[];
prompt: string;
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () =>
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelPrompterReplyTask {
id: string;
type: LabelingTaskType.label_prompter_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () =>
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
@@ -1,17 +1,13 @@
import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI";
import { useGenericTaskAPI } from "../useGenericTaskAPI";
export interface LabelInitialPromptTask {
id: string;
type: "label_initial_prompt";
message_id: string;
valid_labels: string[];
prompt: string;
export const enum LabelingTaskType {
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<LabelInitialPromptTask>("label_initial_prompt");
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
@@ -1,31 +0,0 @@
import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI";
export interface LabelPrompterReplyTask {
id: string;
type: "label_prompter_reply";
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<LabelPrompterReplyTask>("label_prompter_reply");
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
@@ -0,0 +1,46 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelAssistantReplyTaskResponse,
useLabelAssistantReplyTask,
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
const LabelAssistantReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: true, message_id: task.message_id },
];
return (
<LabelTask
title="Label Assistant Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkip={reset}
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelAssistantReply;
@@ -3,7 +3,10 @@ import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelInitialPrompt";
import {
LabelInitialPromptTaskResponse,
useLabelInitialPromptTask,
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
@@ -4,7 +4,10 @@ import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelPrompterReply";
import {
LabelPrompterReplyTaskResponse,
useLabelPrompterReplyTask,
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
const LabelPrompterReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);