mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merging with main
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
.venv
|
||||
venv
|
||||
.env
|
||||
*.pyc
|
||||
*.swp
|
||||
|
||||
@@ -110,3 +110,8 @@ Upon making a release on GitHub, all docker images are automatically built and
|
||||
pushed to ghcr.io. The docker images are tagged with the release version, and
|
||||
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
|
||||
automatically deploy the built release to the dev machine.
|
||||
|
||||
### Contribute a Dataset
|
||||
|
||||
See
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
|
||||
|
||||
@@ -7,6 +7,9 @@ In root directory, run
|
||||
database. The default settings are already configured to connect to the database
|
||||
at `localhost:5432`.
|
||||
|
||||
Python 3.10 is required. It is recommended to use `pyenv` which will recognise
|
||||
the `.python-version` in the project root directory.
|
||||
|
||||
Make sure you have all requirements installed. You can do this by running
|
||||
`pip install -r requirements.txt` inside the `backend` folder and
|
||||
`pip install -e .` inside the `oasst-shared` folder. Then, run the backend using
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -6,6 +7,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -13,6 +15,22 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}", response_model=protocol.User)
|
||||
def query_frontend_user(
|
||||
auth_method: str,
|
||||
username: str,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query frontend user.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
user = ur.query_frontend_user(auth_method, username, api_client_id)
|
||||
return protocol.User(id=user.username, display_name=user.display_name, auth_method=user.auth_method)
|
||||
|
||||
|
||||
@router.get("/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class UserRepository:
|
||||
@@ -11,6 +14,27 @@ class UserRepository:
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
|
||||
def query_frontend_user(
|
||||
self, auth_method: str, username: str, api_client_id: Optional[UUID] = None
|
||||
) -> Optional[User]:
|
||||
if not api_client_id:
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if not self.api_client.trusted and api_client_id != self.api_client.id:
|
||||
# Unprivileged API client asks for foreign user
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(User.auth_method == auth_method, User.username == username, User.api_client_id == api_client_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
|
||||
return user
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
|
||||
@@ -9,7 +9,7 @@ class OasstErrorCode(IntEnum):
|
||||
Ranges:
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
2000-3000: prompt_repository, task_repository, user_repository
|
||||
3000-4000: external resources
|
||||
"""
|
||||
|
||||
@@ -45,6 +45,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_NOT_ACK = 2104
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
USER_NOT_FOUND = 2200
|
||||
|
||||
# 3000-4000: external resources
|
||||
HUGGINGFACE_API_ERROR = 3001
|
||||
|
||||
@@ -4,21 +4,22 @@ import clsx from "clsx";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TaskStatus } from "src/components/Tasks/Task";
|
||||
|
||||
export interface TaskControlsProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
task: any;
|
||||
className?: string;
|
||||
onSubmitResponse: (task: { id: string }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
taskStatus: TaskStatus;
|
||||
onSubmit: () => void;
|
||||
onSkip: (reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const isLightMode = colorMode === "light";
|
||||
const endTask = props.tasks[props.tasks.length - 1];
|
||||
return (
|
||||
<section
|
||||
className={clsx(
|
||||
@@ -30,15 +31,16 @@ export const TaskControls = (props: TaskControlsProps) => {
|
||||
}
|
||||
)}
|
||||
>
|
||||
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
|
||||
<TaskInfo id={props.task.id} output="Submit your answer" />
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton
|
||||
onSkip={(reason: string) => {
|
||||
props.onSkipTask(props.tasks[0], reason);
|
||||
}}
|
||||
/>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton colorScheme="blue" data-cy="submit" onClick={() => props.onSubmitResponse(props.tasks[0])}>
|
||||
<SkipButton onSkip={props.onSkip} disabled={props.taskStatus === "SUBMITTED"} />
|
||||
{props.taskStatus !== "SUBMITTED" ? (
|
||||
<SubmitButton
|
||||
colorScheme="blue"
|
||||
data-cy="submit"
|
||||
disabled={props.taskStatus === "NOT_SUBMITTABLE"}
|
||||
onClick={props.onSubmit}
|
||||
>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { TaskControls, TaskControlsProps } from "src/components/Survey/TaskControls";
|
||||
|
||||
interface TaskControlsOverridableProps extends TaskControlsProps {
|
||||
isValid: boolean;
|
||||
prepareForSubmit: () => void;
|
||||
}
|
||||
|
||||
export const TaskControlsOverridable = (props: TaskControlsOverridableProps) => {
|
||||
const { isValid, onSubmitResponse, ...rest } = props;
|
||||
const { isOpen: isModalOpen, onOpen: onOpenModal, onClose: onModalClose } = useDisclosure();
|
||||
|
||||
const unchangedResponsePrompt = () => {
|
||||
onOpenModal();
|
||||
|
||||
// Ideally this happens when the user clicks submit, but we can't
|
||||
// reliably wait for it to be executed before submitting the response
|
||||
// without significant refactoring.
|
||||
// As a result, modal will only display once even if the user doesn't proceed
|
||||
props.prepareForSubmit();
|
||||
};
|
||||
|
||||
const onSubmitResponseOverride = () => {
|
||||
onSubmitResponse(props.tasks[0]);
|
||||
onModalClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Modal isOpen={isModalOpen} onClose={onModalClose} isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalCloseButton />
|
||||
<ModalHeader>Order Unchanged</ModalHeader>
|
||||
<ModalBody>You have not changed the order of the prompts. Are you sure you would like to submit?</ModalBody>
|
||||
<ModalFooter>
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<Button variant={"ghost"} onClick={onModalClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={onSubmitResponseOverride}>Submit anyway</Button>
|
||||
</Flex>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
<TaskControls onSubmitResponse={isValid ? props.onSubmitResponse : unchangedResponsePrompt} {...rest} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -1,67 +1,34 @@
|
||||
import { useState } from "react";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskInfo } from "src/components/Tasks/TaskTypes";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
|
||||
export interface CreateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
taskType: TaskInfo;
|
||||
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
|
||||
const task = tasks[0].task;
|
||||
export const CreateTask = ({ task, taskType, onReplyChanged }: TaskSurveyProps<{ text: string }>) => {
|
||||
const [inputText, setInputText] = useState("");
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputText.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_message",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
const text = event.target.value;
|
||||
onReplyChanged({ content: { text }, state: "VALID" });
|
||||
setInputText(text);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
|
||||
<TrackedTextarea
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: "Reply..." }}
|
||||
/>
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={(task, reason) => {
|
||||
setInputText("");
|
||||
onSkipTask(task, reason);
|
||||
}}
|
||||
onNextTask={onNextTask}
|
||||
/>
|
||||
</div>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
|
||||
</>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
|
||||
<TrackedTextarea
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: "Reply..." }}
|
||||
/>
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,62 +1,39 @@
|
||||
import { useState } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
|
||||
import { MessageTable } from "../Messages/MessageTable";
|
||||
|
||||
export interface EvaluateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
trigger: (update: { id: string; update_type: string; content: { ranking: number[] } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
|
||||
export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgClasses }: EvaluateTaskProps) => {
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "message_ranking",
|
||||
content: {
|
||||
ranking,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
let messages = null;
|
||||
if (tasks[0].task.conversation) {
|
||||
messages = tasks[0].task.conversation.messages;
|
||||
export const EvaluateTask = ({ task, onReplyChanged }: TaskSurveyProps<{ ranking: number[] }>) => {
|
||||
let messages = [];
|
||||
if (task.conversation) {
|
||||
messages = task.conversation.messages;
|
||||
messages = messages.map((message, index) => ({ ...message, id: index }));
|
||||
}
|
||||
|
||||
const sortables = tasks[0].task.replies ? "replies" : "prompts";
|
||||
useEffect(() => {
|
||||
const conversationMsgs = task.conversation ? task.conversation.messages : [];
|
||||
const defaultRanking = conversationMsgs.map((message, index) => index);
|
||||
onReplyChanged({
|
||||
content: { ranking: defaultRanking },
|
||||
state: "DEFAULT",
|
||||
});
|
||||
}, [task.conversation, onReplyChanged]);
|
||||
|
||||
const onRank = (newRanking: number[]) => {
|
||||
onReplyChanged({ content: { ranking: newRanking }, state: "VALID" });
|
||||
};
|
||||
|
||||
const sortables = task.replies ? "replies" : "prompts";
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
{messages ? <MessageTable messages={messages} /> : null}
|
||||
<Sortable items={tasks[0].task[sortables]} onChange={setRanking} className="my-8" />
|
||||
</SurveyCard>
|
||||
|
||||
<TaskControlsOverridable
|
||||
tasks={tasks}
|
||||
isValid={ranking.length === tasks[0].task[sortables].length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={(task, reason) => {
|
||||
setRanking([]);
|
||||
onSkipTask(task, reason);
|
||||
}}
|
||||
onNextTask={onNextTask}
|
||||
/>
|
||||
</div>
|
||||
<SurveyCard className="max-w-7xl mx-auto h-fit mb-24">
|
||||
<h5 className="text-lg font-semibold mb-4">Instructions</h5>
|
||||
<p className="text-lg py-1">
|
||||
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
|
||||
</p>
|
||||
<MessageTable messages={messages} />
|
||||
<Sortable items={task[sortables]} onChange={onRank} className="my-8" />
|
||||
</SurveyCard>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,44 +1,56 @@
|
||||
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
|
||||
import { useEffect, useId, useState } from "react";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export const LabelTask = ({
|
||||
title,
|
||||
desc,
|
||||
messages,
|
||||
inputs,
|
||||
controls,
|
||||
}: {
|
||||
title: string;
|
||||
desc: string;
|
||||
messages: ReactNode;
|
||||
inputs: ReactNode;
|
||||
controls: ReactNode;
|
||||
}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
task,
|
||||
taskType,
|
||||
onReplyChanged,
|
||||
}: TaskSurveyProps<{ text: string; labels: { [k: string]: number }; message_id: string }>) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const card = useMemo(
|
||||
() => (
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{title}</h5>
|
||||
<p className="text-lg py-1">{desc}</p>
|
||||
{messages}
|
||||
</>
|
||||
),
|
||||
[title, desc, messages]
|
||||
);
|
||||
const valid_labels = task.valid_labels;
|
||||
|
||||
useEffect(() => {
|
||||
onReplyChanged({ content: { labels: {}, text: task.reply, message_id: task.message_id }, state: "DEFAULT" });
|
||||
}, [task.reply, task.message_id, onReplyChanged]);
|
||||
|
||||
const onSliderChange = (values: number[]) => {
|
||||
console.assert(valid_labels.length === sliderValues.length);
|
||||
const labels = Object.fromEntries(valid_labels.map((label, i) => [label, sliderValues[i]]));
|
||||
onReplyChanged({ content: { labels, text: task.reply, message_id: task.message_id }, state: "VALID" });
|
||||
setSliderValues(values);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
{card}
|
||||
{inputs}
|
||||
</TwoColumnsWithCards>
|
||||
{controls}
|
||||
</div>
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{taskType.label}</h5>
|
||||
<p className="text-lg py-1">{taskType.overview}</p>
|
||||
|
||||
{task.conversation ? (
|
||||
<MessageTable
|
||||
messages={[
|
||||
...(task.conversation ? task.conversation.messages : []),
|
||||
{
|
||||
text: task.reply,
|
||||
is_assistant: task.type === TaskType.label_assistant_reply,
|
||||
message_id: task.message_id,
|
||||
},
|
||||
]}
|
||||
/>
|
||||
) : (
|
||||
<MessageView text={task.prompt} is_assistant={false} message_id={task.message_id} />
|
||||
)}
|
||||
</>
|
||||
<LabelSliderGroup labelIDs={task.valid_labels} onChange={onSliderChange} />
|
||||
</TwoColumnsWithCards>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -51,10 +63,6 @@ interface LabelSliderGroupProps {
|
||||
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
@@ -65,6 +73,7 @@ export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps)
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
onChange(sliderValues);
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -1,11 +1,35 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { CreateTask } from "src/components/Tasks/CreateTask";
|
||||
import { EvaluateTask } from "src/components/Tasks/EvaluateTask";
|
||||
import { TaskCategory, TaskTypes } from "src/components/Tasks/TaskTypes";
|
||||
import { LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { TaskCategory, TaskInfo, TaskTypes } from "src/components/Tasks/TaskTypes";
|
||||
import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning";
|
||||
import poster from "src/lib/poster";
|
||||
import { TaskContent } from "src/types/Task";
|
||||
import { TaskReplyState } from "src/types/TaskReplyState";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
const task = tasks[0].task;
|
||||
export type TaskStatus = "NOT_SUBMITTABLE" | "DEFAULT" | "SUBMITABLE" | "SUBMITTED";
|
||||
|
||||
export interface TaskSurveyProps<T> {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
task: any;
|
||||
taskType: TaskInfo;
|
||||
onReplyChanged: (state: TaskReplyState<T>) => void;
|
||||
}
|
||||
|
||||
export const Task = ({ task, trigger, mutate }) => {
|
||||
const [taskStatus, setTaskStatus] = useState<TaskStatus>("NOT_SUBMITTABLE");
|
||||
const replyContent = useRef<TaskContent>(null);
|
||||
const [showUnchangedWarning, setShowUnchangedWarning] = useState(false);
|
||||
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === task.type);
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
|
||||
onSuccess: async () => {
|
||||
@@ -13,40 +37,85 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
},
|
||||
});
|
||||
|
||||
const rejectTask = (task: { id: string }, reason: string) => {
|
||||
const rejectTask = (reason: string) => {
|
||||
sendRejection({
|
||||
id: task.id,
|
||||
reason,
|
||||
});
|
||||
};
|
||||
|
||||
function taskTypeComponent(type) {
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === type);
|
||||
const category = taskType.category;
|
||||
switch (category) {
|
||||
const onReplyChanged = useRef((state: TaskReplyState<TaskContent>) => {
|
||||
if (taskStatus === "SUBMITTED") return;
|
||||
|
||||
replyContent.current = state?.content;
|
||||
if (state === null) {
|
||||
if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE");
|
||||
} else if (state.state === "DEFAULT") {
|
||||
if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT");
|
||||
} else if (state.state === "VALID") {
|
||||
if (taskStatus !== "SUBMITABLE") setTaskStatus("SUBMITABLE");
|
||||
}
|
||||
}).current;
|
||||
|
||||
const submitResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "NOT_SUBMITTABLE":
|
||||
return;
|
||||
case "DEFAULT":
|
||||
setShowUnchangedWarning(true);
|
||||
break;
|
||||
case "SUBMITABLE": {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: taskType.update_type,
|
||||
content: replyContent.current,
|
||||
});
|
||||
setTaskStatus("SUBMITTED");
|
||||
break;
|
||||
}
|
||||
case "SUBMITTED":
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
function taskTypeComponent() {
|
||||
switch (taskType.category) {
|
||||
case TaskCategory.Create:
|
||||
return (
|
||||
<CreateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
taskType={taskType}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
return <CreateTask key={task.id} task={task} taskType={taskType} onReplyChanged={onReplyChanged} />;
|
||||
case TaskCategory.Evaluate:
|
||||
return (
|
||||
<EvaluateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
return <EvaluateTask key={task.id} task={task} taskType={taskType} onReplyChanged={onReplyChanged} />;
|
||||
case TaskCategory.Label:
|
||||
return <LabelTask key={task.id} task={task} taskType={taskType} onReplyChanged={onReplyChanged} />;
|
||||
}
|
||||
}
|
||||
|
||||
return taskTypeComponent(task.type);
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
{taskTypeComponent()}
|
||||
<TaskControls
|
||||
task={task}
|
||||
taskStatus={taskStatus}
|
||||
onSubmit={submitResponse}
|
||||
onSkip={rejectTask}
|
||||
onNextTask={mutate}
|
||||
/>
|
||||
<UnchangedWarning
|
||||
show={showUnchangedWarning}
|
||||
title={taskType.unchanged_title || "No changes"}
|
||||
message={taskType.unchanged_message || "Are you sure you would like to submit?"}
|
||||
onClose={() => setShowUnchangedWarning(false)}
|
||||
onSubmitAnyway={() => {
|
||||
if (taskStatus === "DEFAULT") {
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: taskType.update_type,
|
||||
content: replyContent.current,
|
||||
});
|
||||
setTaskStatus("SUBMITTED");
|
||||
setShowUnchangedWarning(false);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -12,6 +12,9 @@ export interface TaskInfo {
|
||||
type: string;
|
||||
overview?: string;
|
||||
instruction?: string;
|
||||
update_type: string;
|
||||
unchanged_title?: string;
|
||||
unchanged_message?: string;
|
||||
}
|
||||
|
||||
export const TaskTypes: TaskInfo[] = [
|
||||
@@ -24,6 +27,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
type: "initial_prompt",
|
||||
overview: "Create an initial message to send to the assistant",
|
||||
instruction: "Provide the initial prompt",
|
||||
update_type: "text_reply_to_message",
|
||||
},
|
||||
{
|
||||
label: "Reply as User",
|
||||
@@ -33,6 +37,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
type: "prompter_reply",
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the user`s reply",
|
||||
update_type: "text_reply_to_message",
|
||||
},
|
||||
{
|
||||
label: "Reply as Assistant",
|
||||
@@ -42,6 +47,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
type: "assistant_reply",
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the assistant`s reply",
|
||||
update_type: "text_reply_to_message",
|
||||
},
|
||||
// evaluate
|
||||
{
|
||||
@@ -50,6 +56,9 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
pathname: "/evaluate/rank_user_replies",
|
||||
type: "rank_prompter_replies",
|
||||
update_type: "message_ranking",
|
||||
unchanged_title: "Order Unchanged",
|
||||
unchanged_message: "You have not changed the order of the prompts. Are you sure you would like to submit?",
|
||||
},
|
||||
{
|
||||
label: "Rank Assistant Replies",
|
||||
@@ -57,6 +66,9 @@ export const TaskTypes: TaskInfo[] = [
|
||||
category: TaskCategory.Evaluate,
|
||||
pathname: "/evaluate/rank_assistant_replies",
|
||||
type: "rank_assistant_replies",
|
||||
update_type: "message_ranking",
|
||||
unchanged_title: "Order Unchanged",
|
||||
unchanged_message: "You have not changed the order of the prompts. Are you sure you would like to submit?",
|
||||
},
|
||||
{
|
||||
label: "Rank Initial Prompts",
|
||||
@@ -64,6 +76,9 @@ export const TaskTypes: TaskInfo[] = [
|
||||
category: TaskCategory.Evaluate,
|
||||
pathname: "/evaluate/rank_initial_prompts",
|
||||
type: "rank_initial_prompts",
|
||||
update_type: "message_ranking",
|
||||
unchanged_title: "Order Unchanged",
|
||||
unchanged_message: "You have not changed the order of the prompts. Are you sure you would like to submit?",
|
||||
},
|
||||
// label
|
||||
{
|
||||
@@ -71,20 +86,26 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
overview: "Provide labels for the following prompt",
|
||||
type: "label_initial_prompt",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
{
|
||||
label: "Label Prompter Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
overview: "Given the following discussion, provide labels for the final promp",
|
||||
type: "label_prompter_reply",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
{
|
||||
label: "Label Assistant Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
overview: "Given the following discussion, provide labels for the final prompt.",
|
||||
type: "label_assistant_reply",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
} from "@chakra-ui/react";
|
||||
|
||||
interface UnchangedWarningProps {
|
||||
show: boolean;
|
||||
title: string;
|
||||
message: string;
|
||||
onClose: () => void;
|
||||
onSubmitAnyway: () => void;
|
||||
}
|
||||
|
||||
export const UnchangedWarning = (props: UnchangedWarningProps) => {
|
||||
return (
|
||||
<>
|
||||
<Modal isOpen={props.show} onClose={props.onClose} isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalCloseButton />
|
||||
<ModalHeader>{props.title}</ModalHeader>
|
||||
<ModalBody>{props.message}</ModalBody>
|
||||
<ModalFooter>
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<Button variant={"ghost"} onClick={props.onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={props.onSubmitAnyway}>Submit anyway</Button>
|
||||
</Flex>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -1,32 +1,9 @@
|
||||
import { BaseTask, TaskResponse, TaskType } from "src/types/Task";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { LabelAssistantReplyTask, LabelInitialPromptTask, LabelPrompterReplyTask } from "src/types/Tasks";
|
||||
|
||||
import { useGenericTaskAPI } from "./useGenericTaskAPI";
|
||||
|
||||
const useLabelingTask = <Task extends BaseTask>(
|
||||
endpoint: TaskType.label_assistant_reply | TaskType.label_prompter_reply | TaskType.label_initial_prompt
|
||||
) => {
|
||||
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<Task>(endpoint);
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
|
||||
console.assert(validLabels.length === labelWeights.length);
|
||||
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
|
||||
|
||||
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
};
|
||||
|
||||
return { tasks, isLoading, submit, reset, error };
|
||||
};
|
||||
|
||||
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
|
||||
|
||||
export const useLabelAssistantReplyTask = () =>
|
||||
useLabelingTask<LabelAssistantReplyTask>(TaskType.label_assistant_reply);
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelInitialPromptTask = () => useLabelingTask<LabelInitialPromptTask>(TaskType.label_initial_prompt);
|
||||
|
||||
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
|
||||
|
||||
export const useLabelPrompterReplyTask = () => useLabelingTask<LabelPrompterReplyTask>(TaskType.label_prompter_reply);
|
||||
useGenericTaskAPI<LabelAssistantReplyTask>(TaskType.label_assistant_reply);
|
||||
export const useLabelInitialPromptTask = () => useGenericTaskAPI<LabelInitialPromptTask>(TaskType.label_initial_prompt);
|
||||
export const useLabelPrompterReplyTask = () => useGenericTaskAPI<LabelPrompterReplyTask>(TaskType.label_prompter_reply);
|
||||
|
||||
+17
-2
@@ -1,5 +1,20 @@
|
||||
import type { NextApiRequest, NextApiResponse } from "next";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { getToken, JWT } from "next-auth/jwt";
|
||||
|
||||
/**
|
||||
* Wraps any API Route handler and verifies that the user does not have the
|
||||
* specified role. Returns a 403 if they do, otherwise runs the handler.
|
||||
*/
|
||||
const withoutRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse, arg2: JWT) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role === role) {
|
||||
res.status(403).end();
|
||||
return;
|
||||
}
|
||||
return handler(req, res, token);
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Wraps any API Route handler and verifies that the user has the appropriate
|
||||
@@ -16,4 +31,4 @@ const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiRes
|
||||
};
|
||||
};
|
||||
|
||||
export default withRole;
|
||||
export { withoutRole, withRole };
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import withRole from "src/lib/auth";
|
||||
import { withRole } from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import withRole from "src/lib/auth";
|
||||
import { withRole } from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
// The number of users to fetch in any request.
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const { id } = req.query;
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/children`, {
|
||||
@@ -22,6 +14,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const { id } = req.query;
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}/conversation`, {
|
||||
@@ -22,6 +14,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const { id } = req.query;
|
||||
|
||||
const messageRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages/${id}`, {
|
||||
@@ -22,6 +14,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(message);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const { id } = req.query;
|
||||
|
||||
if (!id) {
|
||||
@@ -43,6 +35,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(parent);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
@@ -19,6 +11,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
//TODO: add params if needed
|
||||
const params = new URLSearchParams({
|
||||
username: token.sub,
|
||||
@@ -24,6 +16,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
@@ -10,19 +10,10 @@ import prisma from "src/lib/prismadb";
|
||||
* 3) Send and Ack to the Task Backend with our local id for the task.
|
||||
* 4) Return everything to the client.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const { task_type } = req.query;
|
||||
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Fetch the new task.
|
||||
const task = await oasstApiClient.fetchTask(task_type, token);
|
||||
const { task_type } = req.query;
|
||||
const task = await oasstApiClient.fetchTask(task_type as string, token);
|
||||
|
||||
// Store the task and link it to the user..
|
||||
const registeredTask = await prisma.registeredTask.create({
|
||||
@@ -38,6 +29,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(registeredTask);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, reason } = await JSON.parse(req.body);
|
||||
|
||||
@@ -25,6 +17,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json({});
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
|
||||
/**
|
||||
* Sets the Label in the Backend.
|
||||
*
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local message_id, task ID and the interaction contents.
|
||||
const { message_id, label_map, text } = await JSON.parse(req.body);
|
||||
|
||||
@@ -35,6 +27,6 @@ const handler = async (req, res) => {
|
||||
}),
|
||||
});
|
||||
res.status(interactionRes.status).end();
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
@@ -13,15 +13,7 @@ import prisma from "src/lib/prismadb";
|
||||
* 4) Records the new task in our local database.
|
||||
* 5) Returns the newly created task to the client.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, content, update_type } = await JSON.parse(req.body);
|
||||
|
||||
@@ -65,6 +57,6 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send the next task in the sequence to the client.
|
||||
res.status(200).json(newRegisteredTask);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
@@ -8,9 +7,6 @@ import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
|
||||
const AssistantReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const AssistantReply = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
@@ -8,9 +7,6 @@ import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
|
||||
const InitialPrompt = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const InitialPrompt = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -36,10 +36,10 @@ const SummarizeStory = () => {
|
||||
|
||||
// Trigger a mutation that updates the current task. We should probably
|
||||
// signal somewhere that this interaction is being processed.
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const submitResponse = () => {
|
||||
const text = inputText.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
id: tasks[0].task.id,
|
||||
update_type: "text_reply_to_message",
|
||||
content: {
|
||||
text,
|
||||
@@ -88,9 +88,10 @@ const SummarizeStory = () => {
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={fetchNextTask}
|
||||
task={tasks[0].task}
|
||||
taskStatus={"NOT_SUBMITTABLE"}
|
||||
onSubmit={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
onNextTask={fetchNextTask}
|
||||
/>
|
||||
</main>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
@@ -8,9 +7,6 @@ import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
|
||||
const UserReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const UserReply = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
@@ -8,9 +7,6 @@ import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
const RankAssistantReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const RankAssistantReplies = () => {
|
||||
<title>Rank Assistant Replies</title>
|
||||
<meta name="description" content="Rank Assistant Replies." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
@@ -8,9 +7,6 @@ import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
|
||||
const RankInitialPrompts = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const RankInitialPrompts = () => {
|
||||
<title>Rank Initial Prompts</title>
|
||||
<meta name="description" content="Rank initial prompts." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
@@ -8,9 +7,6 @@ import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
const RankUserReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
@@ -25,7 +21,7 @@ const RankUserReplies = () => {
|
||||
<title>Rank User Replies</title>
|
||||
<meta name="description" content="Rank User Replies." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -39,9 +39,9 @@ const RateSummary = () => {
|
||||
|
||||
// Trigger a mutation that updates the current task. We should probably
|
||||
// signal somewhere that this interaction is being processed.
|
||||
const submitResponse = (t) => {
|
||||
const submitResponse = () => {
|
||||
trigger({
|
||||
id: t.id,
|
||||
id: tasks[0].task.id,
|
||||
update_type: "message_rating",
|
||||
content: {
|
||||
rating: rating,
|
||||
@@ -103,9 +103,10 @@ const RateSummary = () => {
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={fetchNextTask}
|
||||
task={tasks[0].task}
|
||||
taskStatus={"NOT_SUBMITTABLE"}
|
||||
onSubmit={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
onNextTask={fetchNextTask}
|
||||
/>
|
||||
</main>
|
||||
|
||||
@@ -1,43 +1,28 @@
|
||||
import { useState } from "react";
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { LabelAssistantReplyTaskResponse, useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
|
||||
const LabelAssistantReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: true, message_id: task.message_id },
|
||||
];
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Assistant Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Assistant Reply</title>
|
||||
<meta name="description" content="Label Assistant Reply" />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,38 +1,28 @@
|
||||
import { useState } from "react";
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Initial Prompt"
|
||||
desc="Provide labels for the following prompt"
|
||||
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
|
||||
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Initial Prompt</title>
|
||||
<meta name="description" content="Label Initial Prompt" />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,43 +1,28 @@
|
||||
import { useState } from "react";
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
if (isLoading) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: false, message_id: task.message_id },
|
||||
];
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Prompter Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<>
|
||||
<Head>
|
||||
<title>Label Prompter Reply</title>
|
||||
<meta name="description" content="Label Prompter Reply" />
|
||||
</Head>
|
||||
<Task key={tasks[0].task.id} task={tasks[0].task} trigger={trigger} mutate={reset} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ export const enum TaskType {
|
||||
label_assistant_reply = "label_assistant_reply",
|
||||
}
|
||||
|
||||
// we need to reconsider how to handle task content types
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export type TaskContent = any;
|
||||
|
||||
export interface ValidLabel {
|
||||
name: string;
|
||||
display_text: string;
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
export interface TaskReplyValid<T> {
|
||||
content: T;
|
||||
state: "VALID";
|
||||
}
|
||||
export interface TaskReplyDefault<T> {
|
||||
content: T;
|
||||
state: "DEFAULT";
|
||||
}
|
||||
|
||||
export type TaskReplyState<T> = TaskReplyValid<T> | TaskReplyDefault<T>;
|
||||
Reference in New Issue
Block a user