mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
Merge pull request #520 from othrayte/skip-with-reason
Prompt for reason for skipping.
This commit is contained in:
@@ -1,9 +1,63 @@
|
||||
import { Button, ButtonProps } from "@chakra-ui/react";
|
||||
import {
|
||||
Button,
|
||||
ButtonProps,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Textarea,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { FaChevronDown } from "react-icons/fa";
|
||||
|
||||
interface SkipButtonProps extends ButtonProps {
|
||||
onSkip: (reason: string) => void;
|
||||
}
|
||||
|
||||
export const SkipButton = ({ onSkip, ...props }: SkipButtonProps) => {
|
||||
const { isOpen, onOpen: showModal, onClose: closeModal } = useDisclosure();
|
||||
const [value, setValue] = useState("");
|
||||
|
||||
const onSubmit = () => {
|
||||
onSkip(value);
|
||||
setValue("");
|
||||
closeModal();
|
||||
};
|
||||
|
||||
export const SkipButton = ({ children, ...props }: ButtonProps) => {
|
||||
return (
|
||||
<Button size="lg" variant="outline" {...props}>
|
||||
{children}
|
||||
</Button>
|
||||
<>
|
||||
<Button size="lg" variant="outline" onClick={showModal} {...props}>
|
||||
Skip
|
||||
</Button>
|
||||
<Modal isOpen={isOpen} onClose={closeModal}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Skip</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Textarea
|
||||
value={value}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
resize="none"
|
||||
placeholder="Any feedback on this task?"
|
||||
/>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<Button colorScheme="blue" mr={3} onClick={onSubmit}>
|
||||
Send
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -11,7 +11,8 @@ export interface TaskControlsProps {
|
||||
tasks: any[];
|
||||
className?: string;
|
||||
onSubmitResponse: (task: { id: string }) => void;
|
||||
onSkip: () => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
@@ -31,13 +32,17 @@ export const TaskControls = (props: TaskControlsProps) => {
|
||||
>
|
||||
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
<SkipButton
|
||||
onSkip={(reason: string) => {
|
||||
props.onSkipTask(props.tasks[0], reason);
|
||||
}}
|
||||
/>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton colorScheme="blue" data-cy="submit" onClick={() => props.onSubmitResponse(props.tasks[0])}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onSkip}>
|
||||
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
import { useState } from "react";
|
||||
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskType } from "./TaskTypes";
|
||||
|
||||
export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses }) => {
|
||||
export interface CreateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
taskType: TaskType;
|
||||
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const [inputText, setInputText] = useState("");
|
||||
@@ -20,11 +32,6 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setInputText("");
|
||||
mutate();
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
};
|
||||
@@ -48,7 +55,15 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={(task, reason) => {
|
||||
setInputText("");
|
||||
onSkipTask(task, reason);
|
||||
}}
|
||||
onNextTask={onNextTask}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -5,7 +5,17 @@ import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverr
|
||||
|
||||
import { MessageTable } from "../Messages/MessageTable";
|
||||
|
||||
export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
export interface EvaluateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
trigger: (update: { id: string; update_type: string; content: { ranking: number[] } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
|
||||
export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgClasses }: EvaluateTaskProps) => {
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
@@ -17,10 +27,6 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setRanking([]);
|
||||
mutate();
|
||||
};
|
||||
let messages = null;
|
||||
if (tasks[0].task.conversation) {
|
||||
messages = tasks[0].task.conversation.messages;
|
||||
@@ -45,7 +51,11 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
isValid={ranking.length == tasks[0].task[sortables].length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
onSkipTask={(task, reason) => {
|
||||
setRanking([]);
|
||||
onSkipTask(task, reason);
|
||||
}}
|
||||
onNextTask={onNextTask}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
import { CreateTask } from "./CreateTask";
|
||||
import { EvaluateTask } from "./EvaluateTask";
|
||||
import { TaskCategory, TaskTypes } from "./TaskTypes";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import poster from "src/lib/poster";
|
||||
|
||||
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
|
||||
onSuccess: async () => {
|
||||
mutate();
|
||||
},
|
||||
});
|
||||
|
||||
const rejectTask = (task: { id: string }, reason: string) => {
|
||||
sendRejection({
|
||||
id: task.id,
|
||||
reason,
|
||||
});
|
||||
};
|
||||
|
||||
function taskTypeComponent(type) {
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === type);
|
||||
const category = taskType.category;
|
||||
@@ -14,13 +29,22 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
<CreateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
mutate={mutate}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
taskType={taskType}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Evaluate:
|
||||
return <EvaluateTask tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />;
|
||||
return (
|
||||
<EvaluateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,17 @@ export enum TaskCategory {
|
||||
Label = "Label",
|
||||
}
|
||||
|
||||
export const TaskTypes = [
|
||||
export interface TaskType {
|
||||
label: string;
|
||||
desc: string;
|
||||
category: TaskCategory;
|
||||
pathname: string;
|
||||
type: string;
|
||||
overview?: string;
|
||||
instruction?: string;
|
||||
}
|
||||
|
||||
export const TaskTypes: TaskType[] = [
|
||||
// create
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
|
||||
@@ -68,6 +68,12 @@ export class OasstApiClient {
|
||||
});
|
||||
}
|
||||
|
||||
async nackTask(taskId: string, reason: string): Promise<void> {
|
||||
return this.post(`/api/v1/tasks/${taskId}/nack`, {
|
||||
reason,
|
||||
});
|
||||
}
|
||||
|
||||
// TODO return a strongly typed Task?
|
||||
// This method is used to record interaction with task while fetching next task.
|
||||
// This is a raw Json type, so we can't use it to strongly type the task.
|
||||
|
||||
@@ -36,9 +36,6 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
// Update the backend with our Task ID
|
||||
await oasstApiClient.ackTask(task.id, registeredTask.id);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(registeredTask);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, reason } = await JSON.parse(req.body);
|
||||
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const id = task.id as string;
|
||||
|
||||
// Update the backend with the rejection
|
||||
await oasstApiClient.nackTask(id, reason);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json({});
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -1,3 +1,4 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
@@ -6,9 +7,11 @@ import prisma from "src/lib/prismadb";
|
||||
* Stores the task interaction with the Task Backend and then returns the next task generated.
|
||||
*
|
||||
* This implicity does a few things:
|
||||
* 1) Stores the answer with the Task Backend.
|
||||
* 2) Records the new task in our local database.
|
||||
* 3) Returns the newly created task to the client.
|
||||
* 1) Records the users answer in our local database.
|
||||
* 2) Accepts the task.
|
||||
* 3) Sends the users answer to the Task Backend.
|
||||
* 4) Records the new task in our local database.
|
||||
* 5) Returns the newly created task to the client.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
@@ -20,7 +23,13 @@ const handler = async (req, res) => {
|
||||
}
|
||||
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id, content, update_type } = await JSON.parse(req.body);
|
||||
const { id: frontendId, content, update_type } = await JSON.parse(req.body);
|
||||
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const id = task.id as string;
|
||||
await oasstApiClient.ackTask(id, registeredTask.id);
|
||||
|
||||
// Log the interaction locally to create our user_post_id needed by the Task
|
||||
// Backend.
|
||||
@@ -29,7 +38,7 @@ const handler = async (req, res) => {
|
||||
content,
|
||||
task: {
|
||||
connect: {
|
||||
id,
|
||||
id: frontendId,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -37,7 +46,7 @@ const handler = async (req, res) => {
|
||||
|
||||
let newTask;
|
||||
try {
|
||||
newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
|
||||
newTask = await oasstApiClient.interactTask(update_type, frontendId, interaction.id, content, token);
|
||||
} catch (err) {
|
||||
return res.status(500).json(err);
|
||||
}
|
||||
|
||||
@@ -87,7 +87,12 @@ const SummarizeStory = () => {
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={fetchNextTask}
|
||||
onNextTask={fetchNextTask}
|
||||
/>
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -102,7 +102,12 @@ const RateSummary = () => {
|
||||
</section>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={fetchNextTask}
|
||||
onNextTask={fetchNextTask}
|
||||
/>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -33,7 +33,8 @@ const LabelAssistantReply = () => {
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,8 @@ const LabelInitialPrompt = () => {
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
|
||||
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
|
||||
}
|
||||
|
||||
@@ -33,7 +33,8 @@ const LabelPrompterReply = () => {
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkip={reset}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user