Revert "Centralize task types"

This reverts commit 97c1f12e11.
This commit is contained in:
Keith Stevens
2023-01-10 16:24:35 +09:00
parent 555113a6f2
commit 062bfdba3a
28 changed files with 188 additions and 163 deletions
+6 -1
View File
@@ -1,10 +1,15 @@
import { Grid } from "@chakra-ui/react";
import { forwardRef, useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { Message } from "src/types/Conversation";
import { FlaggableElement } from "./FlaggableElement";
export interface Message {
text: string;
is_assistant: boolean;
message_id: string;
}
export interface ValidLabel {
name: string;
display_text: string;
+2 -3
View File
@@ -3,14 +3,13 @@ import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import {} from "src/components/Tasks/TaskTypes";
import { TaskType } from "./TaskTypes";
import { TaskType } from "src/components/Tasks/TaskTypes";
export interface CreateTaskProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
taskType: TaskInfo;
taskType: TaskType;
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
onSkipTask: (task: { id: string }, reason: string) => void;
onNextTask: () => void;
+2 -2
View File
@@ -4,7 +4,7 @@ export enum TaskCategory {
Label = "Label",
}
export interface TaskInfo {
export interface TaskType {
label: string;
desc: string;
category: TaskCategory;
@@ -14,7 +14,7 @@ export interface TaskInfo {
instruction?: string;
}
export const TaskTypes: TaskInfo[] = [
export const TaskTypes: TaskType[] = [
// create
{
label: "Create Initial Prompts",
@@ -0,0 +1,9 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface CreateInitialPromptTask {
id: string;
type: "initial_prompt";
hint: string;
}
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>("initial_prompt");
@@ -0,0 +1,24 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface BaseCreateReplyTask {
id: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export interface CreateAssistantReplyTask extends BaseCreateReplyTask {
type: "assistant_reply";
}
export interface CreatePrompterReplyTask extends BaseCreateReplyTask {
type: "prompter_reply";
}
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>("assistant_reply");
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>("prompter_reply");
@@ -0,0 +1,9 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface RankInitialPromptsTask {
id: string;
type: "rank_initial_prompts";
prompts: string[];
}
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>("rank_initial_prompts");
@@ -0,0 +1,25 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface BaseRankRepliesTask {
id: string;
replies: string[];
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
interface RankAssistantRepliesTask extends BaseRankRepliesTask {
type: "rank_assistant_replies";
}
interface RankPrompterRepliesTask extends BaseRankRepliesTask {
type: "rank_prompter_replies";
}
export const useRankAssistantRepliesTask = () => useGenericTaskAPI<RankAssistantRepliesTask>("rank_assistant_replies");
export const useRankPrompterRepliesTask = () => useGenericTaskAPI<RankPrompterRepliesTask>("rank_prompter_replies");
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelAssistantReplyTask {
id: string;
type: LabelingTaskType.label_assistant_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
export const useLabelAssistantReplyTask = () =>
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
@@ -0,0 +1,15 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelInitialPromptTask {
id: string;
type: LabelingTaskType.label_initial_prompt;
message_id: string;
valid_labels: string[];
prompt: string;
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () =>
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelPrompterReplyTask {
id: string;
type: LabelingTaskType.label_prompter_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () =>
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
@@ -0,0 +1,20 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
export const enum LabelingTaskType {
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
@@ -1,8 +0,0 @@
import { TaskType } from "src/types/Task";
import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks";
import { useGenericTaskAPI } from "./useGenericTaskAPI";
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>(TaskType.assistant_reply);
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>(TaskType.prompter_reply);
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>(TaskType.initial_prompt);
+10 -2
View File
@@ -2,11 +2,19 @@ import { useState } from "react";
import type { ValidLabel } from "src/components/Messages";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import { BaseTask, TaskResponse } from "src/types/Task";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskApiEndpoint: string) => {
// TODO: type & centralize types for all tasks
export interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
valid_labels: ValidLabel[];
}
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
type ConcreteTaskResponse = TaskResponse<TaskType>;
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
@@ -1,32 +0,0 @@
import { BaseTask, TaskResponse, TaskType } from "src/types/Task";
import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks";
import { useGenericTaskAPI } from "./useGenericTaskAPI";
const useLabelingTask = <Task extends BaseTask>(
endpoint: TaskType.label_assistant_reply | TaskType.label_prompter_reply | TaskType.label_initial_prompt
) => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<Task>(endpoint);
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
export const useLabelAssistantReplyTask = () =>
useLabelingTask<LabelAssistantReplyTask>(TaskType.label_assistant_reply);
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () => useLabelingTask<LabelInitialPromptTask>(TaskType.label_initial_prompt);
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () => useLabelingTask<LabelPrompterReplyTask>(TaskType.label_prompter_reply);
-12
View File
@@ -1,12 +0,0 @@
import { TaskType } from "src/types/Task";
import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks";
import { useGenericTaskAPI } from "./useGenericTaskAPI";
export const useRankAssistantRepliesTask = () =>
useGenericTaskAPI<RankAssistantRepliesTask>(TaskType.rank_assistant_replies);
export const useRankPrompterRepliesTask = () =>
useGenericTaskAPI<RankPrompterRepliesTask>(TaskType.rank_prompter_replies);
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>(TaskType.rank_initial_prompts);
+1 -1
View File
@@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply";
const AssistantReply = () => {
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
+1 -1
View File
@@ -3,7 +3,7 @@ import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt";
const InitialPrompt = () => {
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
+1 -1
View File
@@ -3,7 +3,7 @@ import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply";
const UserReply = () => {
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
@@ -3,7 +3,7 @@ import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
const RankAssistantReplies = () => {
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
@@ -3,7 +3,7 @@ import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts";
const RankInitialPrompts = () => {
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
@@ -3,7 +3,7 @@ import Head from "next/head";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
const RankUserReplies = () => {
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
@@ -1,10 +1,13 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelAssistantReplyTaskResponse, useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
import { Message } from "src/types/Conversation";
import {
LabelAssistantReplyTaskResponse,
useLabelAssistantReplyTask,
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
const LabelAssistantReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
@@ -3,7 +3,10 @@ import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
import {
LabelInitialPromptTaskResponse,
useLabelInitialPromptTask,
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
@@ -1,10 +1,13 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
import { Message } from "src/types/Conversation";
import {
LabelPrompterReplyTaskResponse,
useLabelPrompterReplyTask,
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
const LabelPrompterReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
+2 -2
View File
@@ -2,9 +2,9 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
import Head from "next/head";
import { useEffect, useState } from "react";
import { getDashboardLayout } from "src/components/Layout";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import fetcher from "src/lib/fetcher";
import { Message } from "src/types/Conversation";
import useSWRImmutable from "swr/immutable";
const MessagesDashboard = () => {
@@ -82,6 +82,6 @@ const MessagesDashboard = () => {
);
};
MessagesDashboard.getLayout = getDashboardLayout;
MessagesDashboard.getLayout = (page) => getDashboardLayout(page);
export default MessagesDashboard;
-9
View File
@@ -1,9 +0,0 @@
export interface Message {
text: string;
is_assistant: boolean;
message_id: string;
}
export interface Conversation {
messages: Message[];
}
-24
View File
@@ -1,24 +0,0 @@
export const enum TaskType {
initial_prompt = "initial_prompt",
assistant_reply = "assistant_reply",
prompter_reply = "prompter_reply",
rank_initial_prompts = "rank_initial_prompts",
rank_assistant_replies = "rank_assistant_replies",
rank_prompter_replies = "rank_prompter_replies",
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
export interface BaseTask {
id: string;
type: TaskType;
}
export interface TaskResponse<Task extends BaseTask> {
id: string;
userId: string;
task: Task;
}
-57
View File
@@ -1,57 +0,0 @@
import { Conversation } from "./Conversation";
import { BaseTask, TaskType } from "./Task";
export interface CreateInitialPromptTask extends BaseTask {
type: TaskType.initial_prompt;
hint: string;
}
export interface CreateAssistantReplyTask extends BaseTask {
type: TaskType.assistant_reply;
conversation: Conversation;
}
export interface CreatePrompterReplyTask extends BaseTask {
type: TaskType.prompter_reply;
conversation: Conversation;
}
export interface RankInitialPromptsTask extends BaseTask {
type: TaskType.rank_initial_prompts;
prompts: string[];
}
export interface RankAssistantRepliesTask extends BaseTask {
type: TaskType.rank_assistant_replies;
conversation: Conversation;
replies: string[];
}
export interface RankPrompterRepliesTask extends BaseTask {
type: TaskType.rank_prompter_replies;
conversation: Conversation;
replies: string[];
}
export interface LabelAssistantReplyTask extends BaseTask {
type: TaskType.label_assistant_reply;
message_id: string;
conversation: Conversation;
reply: string;
valid_labels: string[];
}
export interface LabelInitialPromptTask extends BaseTask {
type: TaskType.label_initial_prompt;
message_id: string;
valid_labels: string[];
prompt: string;
}
export interface LabelPrompterReplyTask extends BaseTask {
type: TaskType.label_prompter_reply;
message_id: string;
conversation: Conversation;
reply: string;
valid_labels: string[];
}