From a19e0fa085d3110c956b1375ffe7162920b59d22 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 8 Jan 2023 09:21:07 +0100 Subject: [PATCH 01/19] Extract generic code out of labeling task --- .../src/components/Loading/LoadingScreen.jsx | 2 +- website/src/components/Tasks/LabelTask.tsx | 100 +++++++++++++++ website/src/hooks/tasks/useGenericTaskAPI.tsx | 42 +++++++ .../src/hooks/tasks/useLabelInitialPrompt.tsx | 27 ++++ website/src/hooks/useLabelingTask.ts | 52 -------- .../src/pages/label/label_initial_prompt.tsx | 117 ++++-------------- 6 files changed, 191 insertions(+), 149 deletions(-) create mode 100644 website/src/components/Tasks/LabelTask.tsx create mode 100644 website/src/hooks/tasks/useGenericTaskAPI.tsx create mode 100644 website/src/hooks/tasks/useLabelInitialPrompt.tsx delete mode 100644 website/src/hooks/useLabelingTask.ts diff --git a/website/src/components/Loading/LoadingScreen.jsx b/website/src/components/Loading/LoadingScreen.jsx index 02aabe7a..3595b3c4 100644 --- a/website/src/components/Loading/LoadingScreen.jsx +++ b/website/src/components/Loading/LoadingScreen.jsx @@ -1,7 +1,7 @@ import { Progress } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; -export const LoadingScreen = ({ text }) => { +export const LoadingScreen = ({ text = "Loading..." } = {}) => { const { colorMode } = useColorMode(); const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; diff --git a/website/src/components/Tasks/LabelTask.tsx b/website/src/components/Tasks/LabelTask.tsx new file mode 100644 index 00000000..bb9d417c --- /dev/null +++ b/website/src/components/Tasks/LabelTask.tsx @@ -0,0 +1,100 @@ +import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; +import { useColorMode } from "@chakra-ui/react"; +import { ReactNode, useEffect, useId, useMemo, useState } from "react"; +import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; +import { colors } from "styles/Theme/colors"; + +export const LabelTask = ({ + title, + desc, + messages, + inputs, + controls, +}: { + title: string; + desc: string; + messages: ReactNode; + inputs: ReactNode; + controls: ReactNode; +}) => { + const { colorMode } = useColorMode(); + const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; + + const card = useMemo( + () => ( + <> +
{title}
+

{desc}

+ {messages} + + ), + [title, desc, messages] + ); + + return ( +
+ + {card} + {inputs} + + {controls} +
+ ); +}; + +// TODO: consolidate with FlaggableElement +interface LabelSliderGroupProps { + labelIDs: Array; + onChange: (sliderValues: number[]) => unknown; +} + +export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => { + const [sliderValues, setSliderValues] = useState(Array.from({ length: labelIDs.length }).map(() => 0)); + + useEffect(() => { + onChange(sliderValues); + }, [sliderValues, onChange]); + + return ( + + {labelIDs.map((labelId, idx) => ( + { + const newState = sliderValues.slice(); + newState[idx] = sliderValue; + setSliderValues(newState); + }} + /> + ))} + + ); +}; + +function CheckboxSliderItem(props: { + labelId: string; + sliderValue: number; + sliderHandler: (newVal: number) => unknown; +}) { + const id = useId(); + const { colorMode } = useColorMode(); + + const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`; + + return ( + <> + + props.sliderHandler(val / 100)}> + + + + + + + ); +} diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx new file mode 100644 index 00000000..1a6c0be9 --- /dev/null +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -0,0 +1,42 @@ +import { useEffect, useState } from "react"; +import fetcher from "src/lib/fetcher"; +import poster from "src/lib/poster"; +import useSWRImmutable from "swr/immutable"; +import useSWRMutation from "swr/mutation"; + +// TODO: type & centralize types for all tasks + +export interface TaskResponse { + id: string; + userId: string; + task: TaskType; +} + +export const useGenericTaskAPI = (taskApiEndpoint: string) => { + type ConcreteTaskResponse = TaskResponse; + + const [tasks, setTasks] = useState([]); + + const { isLoading, mutate, error } = useSWRImmutable( + "/api/new_task/" + taskApiEndpoint, + fetcher, + { + onSuccess: (data) => setTasks([data]), + } + ); + + useEffect(() => { + if (tasks.length === 0 && !isLoading && !error) { + mutate(); + } + }, [tasks, isLoading, mutate, error]); + + const { trigger } = useSWRMutation("/api/update_task", poster, { + onSuccess: async (response) => { + const newTask: ConcreteTaskResponse = await response.json(); + setTasks((oldTasks) => [...oldTasks, newTask]); + }, + }); + + return { tasks, isLoading, trigger, error, reset: mutate }; +}; diff --git a/website/src/hooks/tasks/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/useLabelInitialPrompt.tsx new file mode 100644 index 00000000..5d6ca372 --- /dev/null +++ b/website/src/hooks/tasks/useLabelInitialPrompt.tsx @@ -0,0 +1,27 @@ +import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI"; + +export interface LabelInitialPromptTask { + id: string; + type: "label_initial_prompt"; + message_id: string; + valid_labels: string[]; + prompt: string; +} + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelInitialPromptTask = () => { + const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI("label_initial_prompt"); + + const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { + console.assert(validLabels.length === labelWeights.length); + const labels = validLabels.reduce( + (obj, label, i) => ((obj[label] = labelWeights[i]), obj), + {} as Record + ); + + return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + }; + + return { tasks, isLoading, submit, reset, error }; +}; diff --git a/website/src/hooks/useLabelingTask.ts b/website/src/hooks/useLabelingTask.ts deleted file mode 100644 index 872909b7..00000000 --- a/website/src/hooks/useLabelingTask.ts +++ /dev/null @@ -1,52 +0,0 @@ -import { useEffect, useState } from "react"; -import fetcher from "src/lib/fetcher"; -import poster from "src/lib/poster"; -import useSWRImmutable from "swr/immutable"; -import useSWRMutation from "swr/mutation"; - -// TODO: type & centralize types for all tasks -interface TaskResponse { - id: string; - userId: string; - task: TaskType; -} - -export interface LabelInitialPromptTask { - id: string; - message_id: string; - prompt: string; - type: string; - valid_labels: string[]; -} - -export type LabelInitialPromptTaskResponse = TaskResponse; - -export const useLabelingTask = ({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => { - type ConcreteTaskResponse = TaskResponse; - - const [tasks, setTasks] = useState>([]); - - const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, { - onSuccess: (data: ConcreteTaskResponse) => { - setTasks([data]); - }, - }); - - useEffect(() => { - if (tasks.length === 0 && !isLoading && !error) { - mutate(); - } - }, [tasks, isLoading, mutate, error]); - - const { trigger } = useSWRMutation("/api/update_task", poster, { - onSuccess: async (reply) => { - const newTask: ConcreteTaskResponse = await reply.json(); - setTasks((oldTasks) => [...oldTasks, newTask]); - }, - }); - - const submit = (id: string, message_id: string, text: string, labels: Record) => - trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); - - return { tasks, isLoading, submit, error, reset: mutate }; -}; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index e400e8fd..346362dc 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,113 +1,38 @@ -import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; -import { useColorMode } from "@chakra-ui/react"; -import { useEffect, useId, useState } from "react"; +import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { MessageView } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; -import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; -import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask"; -import { colors } from "styles/Theme/colors"; +import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; +import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelInitialPrompt"; const LabelInitialPrompt = () => { const [sliderValues, setSliderValues] = useState([]); - const { tasks, isLoading, submit, reset } = useLabelingTask({ - taskApiEndpoint: "label_initial_prompt", - }); + const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask(); - const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => { - const labels = task.valid_labels.reduce((obj, label, i) => { - obj[label] = sliderValues[i].toString(); - return obj; - }, {} as Record); - - submit(id, task.message_id, task.prompt, labels); - }; - - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return No tasks found...; + if (isLoading || tasks.length === 0) { + return ; } const task = tasks[0].task; return ( -
- - <> -
Label Initial Prompt
-

Provide labels for the following prompt

- - - -
- -
+ } + inputs={} + controls={ + + submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues) + } + /> + } + /> ); }; export default LabelInitialPrompt; - -// TODO: consolidate with FlaggableElement - -interface CheckboxSliderGroupProps { - labelIDs: Array; - onChange: (sliderValues: number[]) => unknown; -} - -const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => { - const [sliderValues, setSliderValues] = useState(Array.from({ length: labelIDs.length }).map(() => 0)); - - useEffect(() => { - onChange(sliderValues); - }, [sliderValues, onChange]); - - return ( - - {labelIDs.map((labelId, idx) => ( - { - const newState = sliderValues.slice(); - newState[idx] = sliderValue; - setSliderValues(newState); - }} - /> - ))} - - ); -}; - -function CheckboxSliderItem(props: { - labelId: string; - sliderValue: number; - sliderHandler: (newVal: number) => unknown; -}) { - const id = useId(); - const { colorMode } = useColorMode(); - - const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`; - - return ( - <> - - props.sliderHandler(val / 100)}> - - - - - - - ); -} From f7dceee87a56c6d828b7ee4739f260d01e8738d1 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 8 Jan 2023 09:49:50 +0100 Subject: [PATCH 02/19] Add LabelPrompterReplyTask --- website/src/components/Messages.tsx | 3 +- website/src/components/Tasks/TaskTypes.tsx | 7 +++ .../src/hooks/tasks/useLabelInitialPrompt.tsx | 5 +-- .../src/hooks/tasks/useLabelPrompterReply.ts | 30 +++++++++++++ website/src/middleware.ts | 4 +- .../src/pages/label/label_prompter_reply.tsx | 44 +++++++++++++++++++ 6 files changed, 85 insertions(+), 8 deletions(-) create mode 100644 website/src/hooks/tasks/useLabelPrompterReply.ts create mode 100644 website/src/pages/label/label_prompter_reply.tsx diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 226c6154..fb84559e 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -12,8 +12,7 @@ export interface Message { export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { const items = messages.map((messageProps: Message, i: number) => { - const { message_id } = messageProps; - const { text } = messageProps; + const { message_id, text } = messageProps; return ( diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 7cec2177..09e106b6 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -63,4 +63,11 @@ export const TaskTypes = [ pathname: "/label/label_initial_prompt", type: "label_initial_prompt", }, + { + label: "Label Prompter Reply", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_prompter_reply", + type: "label_prompter_reply", + }, ]; diff --git a/website/src/hooks/tasks/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/useLabelInitialPrompt.tsx index 5d6ca372..69ab4bcc 100644 --- a/website/src/hooks/tasks/useLabelInitialPrompt.tsx +++ b/website/src/hooks/tasks/useLabelInitialPrompt.tsx @@ -15,10 +15,7 @@ export const useLabelInitialPromptTask = () => { const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { console.assert(validLabels.length === labelWeights.length); - const labels = validLabels.reduce( - (obj, label, i) => ((obj[label] = labelWeights[i]), obj), - {} as Record - ); + const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); }; diff --git a/website/src/hooks/tasks/useLabelPrompterReply.ts b/website/src/hooks/tasks/useLabelPrompterReply.ts new file mode 100644 index 00000000..9b7a61da --- /dev/null +++ b/website/src/hooks/tasks/useLabelPrompterReply.ts @@ -0,0 +1,30 @@ +import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI"; + +export interface LabelPrompterReplyTask { + id: string; + type: "label_prompter_reply"; + message_id: string; + valid_labels: string[]; + reply: string; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + }>; + }; +} + +export type LabelPrompterReplyTaskResponse = TaskResponse; + +export const useLabelPrompterReplyTask = () => { + const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI("label_prompter_reply"); + + const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { + console.assert(validLabels.length === labelWeights.length); + const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); + + return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + }; + + return { tasks, isLoading, submit, reset, error }; +}; diff --git a/website/src/middleware.ts b/website/src/middleware.ts index b6a539b4..d1cd6801 100644 --- a/website/src/middleware.ts +++ b/website/src/middleware.ts @@ -1,8 +1,8 @@ export { default } from "next-auth/middleware"; /** - * Guards all pages under `/grading` and redirects them to the sign in page. + * Guards these pages and redirects them to the sign in page. */ export const config = { - matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"], + matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"], }; diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx new file mode 100644 index 00000000..743bde97 --- /dev/null +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -0,0 +1,44 @@ +import { useState } from "react"; +import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Message, Messages } from "src/components/Messages"; +import { TaskControls } from "src/components/Survey/TaskControls"; +import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; +import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelPrompterReply"; + +const LabelPrompterReply = () => { + const [sliderValues, setSliderValues] = useState([]); + + const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask(); + + if (isLoading || tasks.length === 0) { + return ; + } + + const task = tasks[0].task; + const messages: Message[] = [ + // TODO: could we re-use the task message_id as message id for all messages in the conversation? + // or should we ask the backend team to send message ids in the task? + ...task.conversation.messages.map((m) => ({ ...m, message_id: null })), + { text: task.reply, is_assistant: false, message_id: task.message_id }, + ]; + + return ( + } + inputs={} + controls={ + + submit(id, task.message_id, task.reply, task.valid_labels, sliderValues) + } + /> + } + /> + ); +}; + +export default LabelPrompterReply; From fd9edf29d59221974cd416ec56ea54d2d9e9bf37 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 8 Jan 2023 10:52:48 +0100 Subject: [PATCH 03/19] Use MessageTable --- website/src/pages/label/label_prompter_reply.tsx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 743bde97..81be8fc3 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,6 +1,7 @@ import { useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Message, Messages } from "src/components/Messages"; +import { Message } from "src/components/Messages"; +import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelPrompterReply"; @@ -26,7 +27,7 @@ const LabelPrompterReply = () => { } + messages={} inputs={} controls={ Date: Sun, 8 Jan 2023 10:58:24 +0100 Subject: [PATCH 04/19] Updating task schema according to backend --- website/src/hooks/tasks/useLabelPrompterReply.ts | 1 + website/src/pages/label/label_prompter_reply.tsx | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/website/src/hooks/tasks/useLabelPrompterReply.ts b/website/src/hooks/tasks/useLabelPrompterReply.ts index 9b7a61da..f048c2b3 100644 --- a/website/src/hooks/tasks/useLabelPrompterReply.ts +++ b/website/src/hooks/tasks/useLabelPrompterReply.ts @@ -10,6 +10,7 @@ export interface LabelPrompterReplyTask { messages: Array<{ text: string; is_assistant: boolean; + message_id: string; }>; }; } diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 81be8fc3..44606f47 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -17,9 +17,7 @@ const LabelPrompterReply = () => { const task = tasks[0].task; const messages: Message[] = [ - // TODO: could we re-use the task message_id as message id for all messages in the conversation? - // or should we ask the backend team to send message ids in the task? - ...task.conversation.messages.map((m) => ({ ...m, message_id: null })), + ...task.conversation.messages, { text: task.reply, is_assistant: false, message_id: task.message_id }, ]; From 5c48e34abb676dd733fb22de3642a1b21a6caa9b Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 8 Jan 2023 11:32:42 +0100 Subject: [PATCH 05/19] Fix infinite fetch loop in messages if user has not submitted anything --- website/src/pages/messages/index.tsx | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index c32a90c6..2809ba5c 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -2,6 +2,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha import Head from "next/head"; import { useEffect, useState } from "react"; import { getDashboardLayout } from "src/components/Layout"; +import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import fetcher from "src/lib/fetcher"; import useSWRImmutable from "swr/immutable"; @@ -10,29 +11,28 @@ const MessagesDashboard = () => { const boxBgColor = useColorModeValue("white", "gray.700"); const boxAccentColor = useColorModeValue("gray.200", "gray.900"); - const [messages, setMessages] = useState([]); - const [userMessages, setUserMessages] = useState([]); + const [messages, setMessages] = useState(null); + const [userMessages, setUserMessages] = useState(null); const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, { - onSuccess: (data) => { - setMessages(data); - }, + onSuccess: setMessages, }); const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, { - onSuccess: (data) => { - setUserMessages(data); - }, + onSuccess: setUserMessages, }); + const receivedMessages = !isLoadingAll && Array.isArray(messages); + const receivedUserMessages = !isLoadingUser && Array.isArray(userMessages); + useEffect(() => { - if (messages.length == 0) { + if (!receivedMessages) { mutateAll(); } - if (userMessages.length == 0) { + if (!receivedUserMessages) { mutateUser(); } - }, [messages, userMessages]); + }, [receivedMessages, mutateAll, receivedUserMessages, mutateUser]); return ( <> @@ -52,7 +52,7 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {isLoadingAll ? : } + {receivedMessages ? : } @@ -66,7 +66,7 @@ const MessagesDashboard = () => { borderRadius="xl" className="p-6 shadow-sm" > - {isLoadingUser ? : } + {receivedUserMessages ? : } From c9a3813b8f7622b88c663ec20c8eb9de67aa7308 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 8 Jan 2023 12:07:26 +0100 Subject: [PATCH 06/19] Add Label Assistant Reply Task --- website/src/components/Tasks/TaskTypes.tsx | 6 +++ .../tasks/labeling/useLabelAssistantReply.ts | 22 +++++++++ .../tasks/labeling/useLabelInitialPrompt.tsx | 15 ++++++ .../tasks/labeling/useLabelPrompterReply.ts | 22 +++++++++ .../useLabelingTask.ts} | 18 +++----- .../src/hooks/tasks/useLabelPrompterReply.ts | 31 ------------- .../src/pages/label/label_assistant_reply.tsx | 46 +++++++++++++++++++ .../src/pages/label/label_initial_prompt.tsx | 5 +- .../src/pages/label/label_prompter_reply.tsx | 5 +- 9 files changed, 126 insertions(+), 44 deletions(-) create mode 100644 website/src/hooks/tasks/labeling/useLabelAssistantReply.ts create mode 100644 website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx create mode 100644 website/src/hooks/tasks/labeling/useLabelPrompterReply.ts rename website/src/hooks/tasks/{useLabelInitialPrompt.tsx => labeling/useLabelingTask.ts} (54%) delete mode 100644 website/src/hooks/tasks/useLabelPrompterReply.ts create mode 100644 website/src/pages/label/label_assistant_reply.tsx diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 09e106b6..9cbad58a 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -69,5 +69,11 @@ export const TaskTypes = [ category: TaskCategory.Label, pathname: "/label/label_prompter_reply", type: "label_prompter_reply", + },{ + label: "Label Assistant Reply", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_assistant_reply", + type: "label_assistant_reply", }, ]; diff --git a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts new file mode 100644 index 00000000..93f87188 --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts @@ -0,0 +1,22 @@ +import { TaskResponse } from "../useGenericTaskAPI"; +import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; + +export interface LabelAssistantReplyTask { + id: string; + type: "label_prompter_reply"; + message_id: string; + valid_labels: string[]; + reply: string; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + message_id: string; + }>; + }; +} + +export type LabelAssistantReplyTaskResponse = TaskResponse; + +export const useLabelAssistantReplyTask = () => + useLabelingTask(LabelingTaskType.label_assistant_reply); diff --git a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx new file mode 100644 index 00000000..47138825 --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx @@ -0,0 +1,15 @@ +import { TaskResponse } from "../useGenericTaskAPI"; +import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; + +export interface LabelInitialPromptTask { + id: string; + type: "label_initial_prompt"; + message_id: string; + valid_labels: string[]; + prompt: string; +} + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelInitialPromptTask = () => + useLabelingTask(LabelingTaskType.label_initial_prompt); diff --git a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts new file mode 100644 index 00000000..634305cb --- /dev/null +++ b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts @@ -0,0 +1,22 @@ +import { TaskResponse } from "../useGenericTaskAPI"; +import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; + +export interface LabelPrompterReplyTask { + id: string; + type: "label_prompter_reply"; + message_id: string; + valid_labels: string[]; + reply: string; + conversation: { + messages: Array<{ + text: string; + is_assistant: boolean; + message_id: string; + }>; + }; +} + +export type LabelPrompterReplyTaskResponse = TaskResponse; + +export const useLabelPrompterReplyTask = () => + useLabelingTask(LabelingTaskType.label_prompter_reply); diff --git a/website/src/hooks/tasks/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelingTask.ts similarity index 54% rename from website/src/hooks/tasks/useLabelInitialPrompt.tsx rename to website/src/hooks/tasks/labeling/useLabelingTask.ts index 69ab4bcc..27555284 100644 --- a/website/src/hooks/tasks/useLabelInitialPrompt.tsx +++ b/website/src/hooks/tasks/labeling/useLabelingTask.ts @@ -1,17 +1,13 @@ -import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI"; +import { useGenericTaskAPI } from "../useGenericTaskAPI"; -export interface LabelInitialPromptTask { - id: string; - type: "label_initial_prompt"; - message_id: string; - valid_labels: string[]; - prompt: string; +export const enum LabelingTaskType { + label_initial_prompt = "label_initial_prompt", + label_prompter_reply = "label_prompter_reply", + label_assistant_reply = "label_assistant_reply", } -export type LabelInitialPromptTaskResponse = TaskResponse; - -export const useLabelInitialPromptTask = () => { - const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI("label_initial_prompt"); +export const useLabelingTask = (endpoint: LabelingTaskType) => { + const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint); const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { console.assert(validLabels.length === labelWeights.length); diff --git a/website/src/hooks/tasks/useLabelPrompterReply.ts b/website/src/hooks/tasks/useLabelPrompterReply.ts deleted file mode 100644 index f048c2b3..00000000 --- a/website/src/hooks/tasks/useLabelPrompterReply.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { TaskResponse, useGenericTaskAPI } from "./useGenericTaskAPI"; - -export interface LabelPrompterReplyTask { - id: string; - type: "label_prompter_reply"; - message_id: string; - valid_labels: string[]; - reply: string; - conversation: { - messages: Array<{ - text: string; - is_assistant: boolean; - message_id: string; - }>; - }; -} - -export type LabelPrompterReplyTaskResponse = TaskResponse; - -export const useLabelPrompterReplyTask = () => { - const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI("label_prompter_reply"); - - const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => { - console.assert(validLabels.length === labelWeights.length); - const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]])); - - return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); - }; - - return { tasks, isLoading, submit, reset, error }; -}; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx new file mode 100644 index 00000000..d314a907 --- /dev/null +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -0,0 +1,46 @@ +import { useState } from "react"; +import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Message } from "src/components/Messages"; +import { MessageTable } from "src/components/Messages/MessageTable"; +import { TaskControls } from "src/components/Survey/TaskControls"; +import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; +import { + LabelAssistantReplyTaskResponse, + useLabelAssistantReplyTask, +} from "src/hooks/tasks/labeling/useLabelAssistantReply"; + +const LabelAssistantReply = () => { + const [sliderValues, setSliderValues] = useState([]); + + const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask(); + + if (isLoading || tasks.length === 0) { + return ; + } + + const task = tasks[0].task; + const messages: Message[] = [ + ...task.conversation.messages, + { text: task.reply, is_assistant: true, message_id: task.message_id }, + ]; + + return ( + } + inputs={} + controls={ + + submit(id, task.message_id, task.reply, task.valid_labels, sliderValues) + } + /> + } + /> + ); +}; + +export default LabelAssistantReply; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 346362dc..7d1c606b 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -3,7 +3,10 @@ import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { MessageView } from "src/components/Messages"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelInitialPromptTaskResponse, useLabelInitialPromptTask } from "src/hooks/tasks/useLabelInitialPrompt"; +import { + LabelInitialPromptTaskResponse, + useLabelInitialPromptTask, +} from "src/hooks/tasks/labeling/useLabelInitialPrompt"; const LabelInitialPrompt = () => { const [sliderValues, setSliderValues] = useState([]); diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 44606f47..b5742b23 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -4,7 +4,10 @@ import { Message } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; import { TaskControls } from "src/components/Survey/TaskControls"; import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask"; -import { LabelPrompterReplyTaskResponse, useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelPrompterReply"; +import { + LabelPrompterReplyTaskResponse, + useLabelPrompterReplyTask, +} from "src/hooks/tasks/labeling/useLabelPrompterReply"; const LabelPrompterReply = () => { const [sliderValues, setSliderValues] = useState([]); From 39485c6cedefa962a8378ff04a55730e7cd3c654 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sun, 8 Jan 2023 20:11:26 +0900 Subject: [PATCH 07/19] Adding simple pagination to the admin user's view --- website/src/components/UsersCell.tsx | 90 +++++++++++++++++++--------- website/src/pages/api/admin/users.ts | 13 +++- 2 files changed, 73 insertions(+), 30 deletions(-) diff --git a/website/src/components/UsersCell.tsx b/website/src/components/UsersCell.tsx index 3cf2f1e5..fc93215f 100644 --- a/website/src/components/UsersCell.tsx +++ b/website/src/components/UsersCell.tsx @@ -1,4 +1,17 @@ -import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react"; +import { + Button, + Flex, + Spacer, + Stack, + Table, + TableCaption, + TableContainer, + Tbody, + Td, + Th, + Thead, + Tr, +} from "@chakra-ui/react"; import Link from "next/link"; import { useState } from "react"; import fetcher from "src/lib/fetcher"; @@ -8,41 +21,60 @@ import useSWR from "swr"; * Fetches users from the users api route and then presents them in a simple Chakra table. */ const UsersCell = () => { - // Fetch and save the users. + const [pageIndex, setPageIndex] = useState(0); const [users, setUsers] = useState([]); - const { isLoading } = useSWR("/api/admin/users", fetcher, { + + // Fetch and save the users. + // This follows useSWR's recommendation for simple pagination: + // https://swr.vercel.app/docs/pagination#when-to-use-useswr + const { isLoading } = useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, { onSuccess: setUsers, }); + const toPreviousPage = () => { + setPageIndex(Math.max(0, pageIndex - 1)); + }; + + const toNextPage = () => { + setPageIndex(pageIndex + 1); + }; + // Present users in a naive table. return ( - - - Users - - - - - - - - - - - {users.map((user, index) => ( - - - - - - + + + + + + + +
IdEmailNameRoleUpdate
{user.id}{user.email}{user.name}{user.role} - Manage -
+ Users + + + + + + + - ))} - -
IdEmailNameRoleUpdate
-
+ + + {users.map((user, index) => ( + + {user.id} + {user.email} + {user.name} + {user.role} + + Manage + + + ))} + + + + ); }; diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index 1490522a..3418e5e1 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -2,11 +2,21 @@ import { getToken } from "next-auth/jwt"; import withRole from "src/lib/auth"; import prisma from "src/lib/prismadb"; +// The number of users to fetch in any request. +const PAGE_SIZE = 20; + /** * Returns a list of user results from the database when the requesting user is * a logged in admin. */ const handler = withRole("admin", async (req, res) => { + // Figure out the pagination index and skip that number of users. + // + // Note: with Prisma this isn't the most efficient but it's the only possible + // option with cuid based User IDs. + const { pageIndex } = req.query; + const skip = pageIndex * PAGE_SIZE; + // Fetch 20 users. const users = await prisma.user.findMany({ select: { @@ -15,7 +25,8 @@ const handler = withRole("admin", async (req, res) => { name: true, email: true, }, - take: 20, + skip, + take: PAGE_SIZE, }); res.status(200).json(users); From d99c0e2d2f2cb40e6f15314ddb1059a1f8963bd0 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 8 Jan 2023 12:11:47 +0100 Subject: [PATCH 08/19] Use enum for type definition --- website/src/components/Tasks/TaskTypes.tsx | 3 ++- website/src/hooks/tasks/labeling/useLabelAssistantReply.ts | 2 +- website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx | 2 +- website/src/hooks/tasks/labeling/useLabelPrompterReply.ts | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 9cbad58a..82ba1917 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -69,7 +69,8 @@ export const TaskTypes = [ category: TaskCategory.Label, pathname: "/label/label_prompter_reply", type: "label_prompter_reply", - },{ + }, + { label: "Label Assistant Reply", desc: "Provide labels for a prompt.", category: TaskCategory.Label, diff --git a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts index 93f87188..3c44046e 100644 --- a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts +++ b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts @@ -3,7 +3,7 @@ import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; export interface LabelAssistantReplyTask { id: string; - type: "label_prompter_reply"; + type: LabelingTaskType.label_assistant_reply; message_id: string; valid_labels: string[]; reply: string; diff --git a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx index 47138825..f7ba8ab5 100644 --- a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx +++ b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx @@ -3,7 +3,7 @@ import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; export interface LabelInitialPromptTask { id: string; - type: "label_initial_prompt"; + type: LabelingTaskType.label_initial_prompt; message_id: string; valid_labels: string[]; prompt: string; diff --git a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts index 634305cb..9de2057f 100644 --- a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts +++ b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts @@ -3,7 +3,7 @@ import { LabelingTaskType, useLabelingTask } from "./useLabelingTask"; export interface LabelPrompterReplyTask { id: string; - type: "label_prompter_reply"; + type: LabelingTaskType.label_prompter_reply; message_id: string; valid_labels: string[]; reply: string; From e69715fbec9d3a52969b6cfda69ef1ba8bcf5f6e Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sun, 8 Jan 2023 20:21:13 +0900 Subject: [PATCH 09/19] Fixing some lint errors with the new admin features --- website/src/components/UsersCell.tsx | 2 +- website/src/lib/auth.ts | 2 +- website/src/pages/admin/index.tsx | 2 +- website/src/pages/admin/manage_user/[id].tsx | 4 ++-- website/src/pages/api/admin/update_user.ts | 1 - website/src/pages/api/admin/users.ts | 3 +-- 6 files changed, 6 insertions(+), 8 deletions(-) diff --git a/website/src/components/UsersCell.tsx b/website/src/components/UsersCell.tsx index fc93215f..5354ee5c 100644 --- a/website/src/components/UsersCell.tsx +++ b/website/src/components/UsersCell.tsx @@ -27,7 +27,7 @@ const UsersCell = () => { // Fetch and save the users. // This follows useSWR's recommendation for simple pagination: // https://swr.vercel.app/docs/pagination#when-to-use-useswr - const { isLoading } = useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, { + useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, { onSuccess: setUsers, }); diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts index 1a0387f9..5fa20f48 100644 --- a/website/src/lib/auth.ts +++ b/website/src/lib/auth.ts @@ -5,7 +5,7 @@ import { getToken } from "next-auth/jwt"; * Wraps any API Route handler and verifies that the user has the appropriate * role before running the handler. Returns a 403 otherwise. */ -const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => any) => { +const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => { return async (req: NextApiRequest, res: NextApiResponse) => { const token = await getToken({ req }); if (!token || token.role !== role) { diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx index 705a188b..9cbea222 100644 --- a/website/src/pages/admin/index.tsx +++ b/website/src/pages/admin/index.tsx @@ -26,7 +26,7 @@ const AdminIndex = () => { return; } router.push("/"); - }, [session, status]); + }, [router, session, status]); return ( <> diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index ead55224..cdd4746e 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -1,4 +1,4 @@ -import { Box, Button, Container, Flex, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react"; +import { Button, Container, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react"; import { Field, Form, Formik } from "formik"; import Head from "next/head"; import { useRouter } from "next/router"; @@ -27,7 +27,7 @@ const ManageUser = ({ user }) => { return; } router.push("/"); - }, [session, status]); + }, [router, session, status]); // Trigger to let us update the user's role. Triggers a toast when complete. const { trigger } = useSWRMutation("/api/admin/update_user", poster, { diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts index d29fce7c..a717e3d8 100644 --- a/website/src/pages/api/admin/update_user.ts +++ b/website/src/pages/api/admin/update_user.ts @@ -1,4 +1,3 @@ -import { getToken } from "next-auth/jwt"; import withRole from "src/lib/auth"; import prisma from "src/lib/prismadb"; diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index 3418e5e1..ea8d59d9 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,4 +1,3 @@ -import { getToken } from "next-auth/jwt"; import withRole from "src/lib/auth"; import prisma from "src/lib/prismadb"; @@ -15,7 +14,7 @@ const handler = withRole("admin", async (req, res) => { // Note: with Prisma this isn't the most efficient but it's the only possible // option with cuid based User IDs. const { pageIndex } = req.query; - const skip = pageIndex * PAGE_SIZE; + const skip = parseInt(pageIndex as string) * PAGE_SIZE || 0; // Fetch 20 users. const users = await prisma.user.findMany({ From 78fac2b5f51d2ffc6e3d5b2cc7658676d8cc8ae6 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sun, 8 Jan 2023 11:44:57 +0000 Subject: [PATCH 10/19] Refine label for harmful --- oasst-shared/oasst_shared/schemas/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 190ffa43..ee021f14 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -276,10 +276,10 @@ class TextLabel(str, enum.Enum): fails_task = "fails_task", "Fails to follow the correct instruction / task" not_appropriate = "not_appropriate", "Inappropriate for customer assistant" violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm" - harmful = ( - "harmful", - "Harmful content", - "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the label for encouraging violence/abuse/terrorism/self-harm.", + excessive_harm = ( + "excessive_harm", + "Content likely to cause excessive harm not justifiable in the context", + "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", ) sexual_content = "sexual_content", "Contains sexual content" toxicity = "toxicity", "Contains rude, abusive, profane or insulting content" From 570d39edec84d8d10642f071f5726182bde70a3b Mon Sep 17 00:00:00 2001 From: Oliver Date: Sun, 8 Jan 2023 11:51:32 +0000 Subject: [PATCH 11/19] Refine label for hate_speech --- oasst-shared/oasst_shared/schemas/protocol.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 190ffa43..530c4f45 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -286,7 +286,11 @@ class TextLabel(str, enum.Enum): moral_judgement = "moral_judgement", "Expresses moral judgement" political_content = "political_content", "Expresses political views" humor = "humor", "Contains humorous content including sarcasm" - hate_speech = "hate_speech", "Expresses sentiment which is discriminatory against a grouping of people" + hate_speech = ( + "hate_speech", + "Content is abusive or threatening and expresses prejudice against a protected characteristic", + "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + ) threat = "threat", "Contains a threat against a person or persons" misleading = "misleading", "Contains text which is incorrect or misleading" helpful = "helpful", "Completes the task to a high standard" From b0952bc6819e82806e28653405727f4139d1e8d8 Mon Sep 17 00:00:00 2001 From: Callum Date: Sun, 8 Jan 2023 13:33:29 +0000 Subject: [PATCH 12/19] #224: display OAuthAccountNotLinked error message + other errors messages --- website/src/pages/auth/signin.tsx | 50 ++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx index 59fc7c05..14f81b4c 100644 --- a/website/src/pages/auth/signin.tsx +++ b/website/src/pages/auth/signin.tsx @@ -2,17 +2,60 @@ import { Button, Input, Stack } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import Link from "next/link"; +import { useRouter } from "next/router"; import { getCsrfToken, getProviders, signIn } from "next-auth/react"; -import React, { useRef } from "react"; +import React, { useRef, useEffect, useState } from "react"; import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; import { AuthLayout } from "src/components/AuthLayout"; import { Footer } from "src/components/Footer"; import { Header } from "src/components/Header"; +export type SignInErrorTypes = + | "Signin" + | "OAuthSignin" + | "OAuthCallback" + | "OAuthCreateAccount" + | "EmailCreateAccount" + | "Callback" + | "OAuthAccountNotLinked" + | "EmailSignin" + | "CredentialsSignin" + | "SessionRequired" + | "default"; + +const errors: Record = { + Signin: "Try signing in with a different account.", + OAuthSignin: "Try signing in with a different account.", + OAuthCallback: "Try signing in with the same account you used originally.", + OAuthCreateAccount: "Try signing in with a different account.", + EmailCreateAccount: "Try signing in with a different account.", + Callback: "Try signing in with a different account.", + OAuthAccountNotLinked: "To confirm your identity, sign in with the same account you used originally.", + EmailSignin: "The e-mail could not be sent.", + CredentialsSignin: "Sign in failed. Check the details you provided are correct.", + SessionRequired: "Please sign in to access this page.", + default: "Unable to sign in.", +}; + // eslint-disable-next-line @typescript-eslint/no-unused-vars function Signin({ csrfToken, providers }) { + const router = useRouter(); const { discord, email, github, credentials } = providers; const emailEl = useRef(null); + const [error, setError] = useState(""); + + useEffect(() => { + if (router?.query?.error) { + if (typeof router.query.error === "string") { + const errorType = errors[router.query.error]; + setError(errorType); + } else { + const errorType = errors[router.query.error[0]]; + setError(errorType); + } + } + }, [router]); + const signinWithEmail = (ev: React.FormEvent) => { ev.preventDefault(); signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value }); @@ -110,6 +153,11 @@ function Signin({ csrfToken, providers }) { . + {error && ( +
+

Error: {error}

+
+ )} ); From c4b55d2f1317a1ade6cf22198b162c7e304cca0d Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sun, 8 Jan 2023 16:51:59 +0100 Subject: [PATCH 13/19] updated pre-commit workflow to post advice to users --- .github/workflows/pre-commit.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 0f747b45..0f82185f 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -16,3 +16,12 @@ jobs: with: python-version: "3.10" - uses: pre-commit/action@v3.0.0 + - name: Post PR comment on failure + if: failure() && github.event_name == 'pull_request' + uses: peter-evans/create-or-update-comment@v2 + with: + issue-number: ${{ github.event.pull_request.number }} + body: | + :x: **pre-commit** failed. + Please run `pre-commit run --all-files` locally and commit the changes. + Find more information in the repository's CONTRIBUTING.md From 4769db59c299ac32174be666a71f357be1ae077a Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sun, 8 Jan 2023 16:53:48 +0100 Subject: [PATCH 14/19] added pre-commit guides and purposely did not run pre-commit --- CONTRIBUTING.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 608afe25..54e13418 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,7 +96,9 @@ The website is built using Next.js and is in the `website` folder. ### Pre-commit -Install `pre-commit` and run `pre-commit install` to install the pre-commit +We are using `pre-commit` to enforce code style and formatting. + +Install `pre-commit` from [its website](https://pre-commit.com) and run `pre-commit install` to install the pre-commit hooks. In case you haven't done this, have already committed, and CI is failing, you From dc5a7d5c10bafae7c246dd72850db4fbe9826e5f Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sun, 8 Jan 2023 16:57:17 +0100 Subject: [PATCH 15/19] ran pre-commit --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 54e13418..428f6a50 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -98,8 +98,8 @@ The website is built using Next.js and is in the `website` folder. We are using `pre-commit` to enforce code style and formatting. -Install `pre-commit` from [its website](https://pre-commit.com) and run `pre-commit install` to install the pre-commit -hooks. +Install `pre-commit` from [its website](https://pre-commit.com) and run +`pre-commit install` to install the pre-commit hooks. In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files. From 35292e800a2e444e773cb62458819189ce95876b Mon Sep 17 00:00:00 2001 From: Desmond Grealy Date: Sun, 8 Jan 2023 08:35:58 -0800 Subject: [PATCH 16/19] Add theme to verify page. Simplify --- website/src/pages/auth/verify.tsx | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/website/src/pages/auth/verify.tsx b/website/src/pages/auth/verify.tsx index e004f504..b4d7d739 100644 --- a/website/src/pages/auth/verify.tsx +++ b/website/src/pages/auth/verify.tsx @@ -1,17 +1,23 @@ +import { useColorMode } from "@chakra-ui/react"; import Head from "next/head"; import { getCsrfToken, getProviders } from "next-auth/react"; import { AuthLayout } from "src/components/AuthLayout"; export default function Verify() { + const { colorMode } = useColorMode(); + const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900"; + return ( <> Sign Up - Open Assistant - -

A sign-in link has been sent to your email address.

-
+
+
+

A sign-in link has been sent to your email address.

+
+
); } From 0ec2d7fb0564e4cc133a9b43d3c1eb3c1eec832c Mon Sep 17 00:00:00 2001 From: Kostia Date: Sun, 8 Jan 2023 19:24:08 +0200 Subject: [PATCH 17/19] Add /tasks/all route to website --- website/src/pages/tasks/all.tsx | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 website/src/pages/tasks/all.tsx diff --git a/website/src/pages/tasks/all.tsx b/website/src/pages/tasks/all.tsx new file mode 100644 index 00000000..6e4e926b --- /dev/null +++ b/website/src/pages/tasks/all.tsx @@ -0,0 +1,19 @@ +import Head from "next/head"; +import { TaskOption } from "src/components/Dashboard"; +import { getDashboardLayout } from "src/components/Layout"; + +const AllTasks = () => { + return ( + <> + + All Tasks - Open Assistant + + + + + ); +}; + +AllTasks.getLayout = (page) => getDashboardLayout(page); + +export default AllTasks; From ece8f227c2dd14551da6b1be895425063b13db88 Mon Sep 17 00:00:00 2001 From: Callum Date: Sun, 8 Jan 2023 17:47:08 +0000 Subject: [PATCH 18/19] #224: Simplified error useEffect and renamed errors to errorMessages --- website/src/pages/auth/signin.tsx | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx index 14f81b4c..9a1d91a8 100644 --- a/website/src/pages/auth/signin.tsx +++ b/website/src/pages/auth/signin.tsx @@ -4,7 +4,7 @@ import Head from "next/head"; import Link from "next/link"; import { useRouter } from "next/router"; import { getCsrfToken, getProviders, signIn } from "next-auth/react"; -import React, { useRef, useEffect, useState } from "react"; +import React, { useEffect, useRef, useState } from "react"; import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; import { AuthLayout } from "src/components/AuthLayout"; import { Footer } from "src/components/Footer"; @@ -23,7 +23,7 @@ export type SignInErrorTypes = | "SessionRequired" | "default"; -const errors: Record = { +const errorMessages: Record = { Signin: "Try signing in with a different account.", OAuthSignin: "Try signing in with a different account.", OAuthCallback: "Try signing in with the same account you used originally.", @@ -45,13 +45,12 @@ function Signin({ csrfToken, providers }) { const [error, setError] = useState(""); useEffect(() => { - if (router?.query?.error) { - if (typeof router.query.error === "string") { - const errorType = errors[router.query.error]; - setError(errorType); + const err = router?.query?.error; + if (err) { + if (typeof err === "string") { + setError(errorMessages[err]); } else { - const errorType = errors[router.query.error[0]]; - setError(errorType); + setError(errorMessages[err[0]]); } } }, [router]); From 8906854dbfbf0a3f9bb0c9ce2e53d0f996f534c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 8 Jan 2023 19:08:47 +0100 Subject: [PATCH 19/19] Extract UserRepository and TaskRepository from PromptRepository * Extract classes UserRepository and TaskRepository from PromptRepository * move close_task() to TaskRepository and get_user_leaderboard to UserRepository() * Use UserRepository in leaderboards endpoint, add type annotation to leaderboards endpoint --- backend/main.py | 17 +- .../oasst_backend/api/v1/frontend_messages.py | 14 +- .../oasst_backend/api/v1/frontend_users.py | 4 +- backend/oasst_backend/api/v1/leaderboards.py | 15 +- backend/oasst_backend/api/v1/messages.py | 18 +- backend/oasst_backend/api/v1/stats.py | 2 +- backend/oasst_backend/api/v1/tasks.py | 20 +- backend/oasst_backend/api/v1/text_labels.py | 2 +- backend/oasst_backend/api/v1/users.py | 4 +- .../models/message_tree_state.py | 45 ++- backend/oasst_backend/prompt_repository.py | 326 ++++-------------- backend/oasst_backend/task_repository.py | 199 +++++++++++ backend/oasst_backend/user_repository.py | 64 ++++ 13 files changed, 409 insertions(+), 321 deletions(-) create mode 100644 backend/oasst_backend/task_repository.py create mode 100644 backend/oasst_backend/user_repository.py diff --git a/backend/main.py b/backend/main.py index 1c93fc9f..b84a2d9e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client from oasst_backend.api.v1.api import api_router from oasst_backend.config import settings from oasst_backend.database import engine -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA: with Session(engine) as db: api_client = get_dummy_api_client(db) dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") - pr = PromptRepository(db=db, api_client=api_client, user=dummy_user) + + ur = UserRepository(db=db, api_client=api_client) + tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur) + pr = PromptRepository( + db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr + ) with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: dummy_messages_raw = json.load(f) @@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA: dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] for msg in dummy_messages: - task = pr.fetch_task_by_frontend_message_id(msg.task_message_id) + task = tr.fetch_task_by_frontend_message_id(msg.task_message_id) if task and not task.ack: logger.warning("Deleting unacknowledged seed data task") db.delete(task) task = None if not task: if msg.parent_message_id is None: - task = pr.store_task( + task = tr.store_task( protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None ) else: @@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA: for cmsg in conversation_messages ] ) - task = pr.store_task( + task = tr.store_task( protocol_schema.AssistantReplyTask(conversation=conversation), message_tree_id=parent_message.message_tree_id, parent_message_id=parent_message.id, ) - pr.bind_frontend_message_id(task.id, msg.task_message_id) + tr.bind_frontend_message_id(task.id, msg.task_message_id) message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id) logger.info( diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index 420f0d1b..f149bebb 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -16,7 +16,7 @@ def get_message_by_frontend_id( """ Get a message by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) return utils.prepare_message(message) @@ -29,7 +29,7 @@ def get_conv_by_frontend_id( Get a conversation from the tree root and up to the message with given frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_conversation(message) return utils.prepare_conversation(messages) @@ -43,7 +43,7 @@ def get_tree_by_frontend_id( Get all messages belonging to the same message tree. Message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) tree = pr.fetch_message_tree(message.message_tree_id) return utils.prepare_tree(tree, message.message_tree_id) @@ -56,7 +56,7 @@ def get_children_by_frontend_id( """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_children(message.id) return utils.prepare_message_list(messages) @@ -70,7 +70,7 @@ def get_descendants_by_frontend_id( Get a subtree which starts with this message. The message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id( Get the longest conversation from the tree of the message. The message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -98,7 +98,7 @@ def get_max_children_by_frontend_id( Get message with the most children from the tree of the provided message. The message is identified by its frontend ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message_by_frontend_message_id(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 0a745462..8d56b7f9 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -29,7 +29,7 @@ def query_frontend_user_messages( """ Query frontend user messages. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.query_messages( username=username, api_client_id=api_client_id, @@ -47,6 +47,6 @@ def query_frontend_user_messages( def mark_frontend_user_messages_deleted( username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) messages = pr.query_messages(username=username, api_client_id=api_client.id) pr.mark_messages_deleted(messages) diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py index 4202edad..46aea637 100644 --- a/backend/oasst_backend/api/v1/leaderboards.py +++ b/backend/oasst_backend/api/v1/leaderboards.py @@ -1,7 +1,8 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.user_repository import UserRepository +from oasst_shared.schemas.protocol import LeaderboardStats from sqlmodel import Session router = APIRouter() @@ -11,15 +12,15 @@ router = APIRouter() def get_assistant_leaderboard( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), -): - pr = PromptRepository(db, api_client, None) - return pr.get_user_leaderboard(role="assistant") +) -> LeaderboardStats: + ur = UserRepository(db, api_client) + return ur.get_user_leaderboard(role="assistant") @router.get("/create/prompter") def get_prompter_leaderboard( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), -): - pr = PromptRepository(db, api_client, None) - return pr.get_user_leaderboard(role="prompter") +) -> LeaderboardStats: + ur = UserRepository(db, api_client) + return ur.get_user_leaderboard(role="prompter") diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 7a2fd2e9..6229e20c 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -29,7 +29,7 @@ def query_messages( """ Query messages. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.query_messages( username=username, api_client_id=api_client_id, @@ -51,7 +51,7 @@ def get_message( """ Get a message by its internal ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) return utils.prepare_message(message) @@ -64,7 +64,7 @@ def get_conv( Get a conversation from the tree root and up to the message with given internal ID. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.fetch_message_conversation(message_id) return utils.prepare_conversation(messages) @@ -76,7 +76,7 @@ def get_tree( """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) tree = pr.fetch_message_tree(message.message_tree_id) return utils.prepare_tree(tree, message.message_tree_id) @@ -89,7 +89,7 @@ def get_children( """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.fetch_message_children(message_id) return utils.prepare_message_list(messages) @@ -101,7 +101,7 @@ def get_descendants( """ Get a subtree which starts with this message. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -114,7 +114,7 @@ def get_longest_conv( """ Get the longest conversation from the tree of the message. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -127,7 +127,7 @@ def get_max_children( """ Get message with the most children from the tree of the provided message. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) message = pr.fetch_message(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) @@ -137,5 +137,5 @@ def get_max_children( def mark_message_deleted( message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) pr.mark_messages_deleted(message_id) diff --git a/backend/oasst_backend/api/v1/stats.py b/backend/oasst_backend/api/v1/stats.py index a54aa07b..1aaffb1b 100644 --- a/backend/oasst_backend/api/v1/stats.py +++ b/backend/oasst_backend/api/v1/stats.py @@ -13,5 +13,5 @@ def get_message_stats( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) return pr.get_stats() diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index adfb2907..eb10dc00 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -7,7 +7,7 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.api.v1.utils import prepare_conversation -from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.prompt_repository import PromptRepository, TaskRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -190,9 +190,9 @@ def request_task( api_client = deps.api_auth(api_key, db) try: - pr = PromptRepository(db, api_client, request.user) + pr = PromptRepository(db, api_client, client_user=request.user) task, message_tree_id, parent_message_id = generate_task(request, pr) - pr.store_task(task, message_tree_id, parent_message_id, request.collective) + pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective) except OasstError: raise @@ -217,11 +217,11 @@ def tasks_acknowledge( api_client = deps.api_auth(api_key, db) try: - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) # here we store the message id in the database for the task logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.") - pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id) + pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id) except OasstError: raise @@ -245,8 +245,8 @@ def tasks_acknowledge_failure( try: logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.") api_client = deps.api_auth(api_key, db) - pr = PromptRepository(db, api_client, user=None) - pr.acknowledge_task_failure(task_id) + pr = PromptRepository(db, api_client) + pr.task_repository.acknowledge_task_failure(task_id) except (KeyError, RuntimeError): logger.exception("Failed to not acknowledge task.") raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED) @@ -265,7 +265,7 @@ def tasks_interaction( api_client = deps.api_auth(api_key, db) try: - pr = PromptRepository(db, api_client, user=interaction.user) + pr = PromptRepository(db, api_client, client_user=interaction.user) match type(interaction): case protocol_schema.TextReplyToMessage: @@ -323,6 +323,6 @@ def close_collective_task( api_key: APIKey = Depends(deps.get_api_key), ): api_client = deps.api_auth(api_key, db) - pr = PromptRepository(db, api_client, user=None) - pr.close_task(close_task_request.message_id) + tr = TaskRepository(db, api_client) + tr.close_task(close_task_request.message_id) return protocol_schema.TaskDone() diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index 03fd2cb4..c9afd88c 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -25,7 +25,7 @@ def label_text( try: logger.info(f"Labeling text {text_labels=}.") - pr = PromptRepository(db, api_client, user=text_labels.user) + pr = PromptRepository(db, api_client, client_user=text_labels.user) pr.store_text_labels(text_labels) except Exception: diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 8d55bfec..5dda88eb 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -29,7 +29,7 @@ def query_user_messages( """ Query user messages. """ - pr = PromptRepository(db, api_client, user=None) + pr = PromptRepository(db, api_client) messages = pr.query_messages( user_id=user_id, api_client_id=api_client_id, @@ -48,6 +48,6 @@ def query_user_messages( def mark_user_messages_deleted( user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): - pr = PromptRepository(db, api_client, None) + pr = PromptRepository(db, api_client) messages = pr.query_messages(user_id=user_id) pr.mark_messages_deleted(messages) diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index 386595e9..97ad34eb 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -6,27 +6,56 @@ import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg from sqlmodel import Field, Index, SQLModel -# The types of States a message tree can have. +class States(str, Enum): + """States of the Open-Assistant message tree state machine.""" + + INITIAL_PROMPT_REVIEW = "initial_prompt_review" + """In this state the message tree consists only of a single inital prompt root node. + Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or + `aborted_low_grade`.""" -class States(Enum): - INITIAL = "initial" BREEDING_PHASE = "breeding_phase" + """Assistant & prompter human demonstrations are collected. Concurrently labeling tasks + are handed out to check if the quality of the replies surpasses the minimum acceptable + quality. + When the required number of messages passing the initial labelling-quality check has been + collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses + are received the tree can also enter the `aborted_low_grade` state.""" + RANKING_PHASE = "ranking_phase" + """The tree has been successfully populated with the desired number of messages. Ranking + tasks are now handed out for all nodes with more than one child.""" + READY_FOR_SCORING = "ready_for_scoring" - CHILDREN_SCORED = "children_scored" - FINAL = "final" + """Required ranking responses have been collected and the scoring algorithm can now + compute the aggergated ranking scores that will appear in the dataset.""" + + READY_FOR_EXPORT = "ready_for_export" + """The Scoring algorithm computed rankings scores for all childern. The message tree can be + exported as part of an Open-Assistant message tree dataset.""" + + SCORING_FAILED = "scoring_failed" + """An exception occured in the scoring algorithm.""" + + ABORTED_LOW_GRADE = "aborted_low_grade" + """The system received too many bad reviews and stopped handing out tasks for this message tree.""" + + HALTED_BY_MODERATOR = "halted_by_moderator" + """A moderator decided to manually halt the message tree construction process.""" VALID_STATES = ( - States.INITIAL, + States.INITIAL_PROMPT_REVIEW, States.BREEDING_PHASE, States.RANKING_PHASE, States.READY_FOR_SCORING, - States.CHILDREN_SCORED, - States.FINAL, + States.READY_FOR_EXPORT, + States.ABORTED_LOW_GRADE, ) +TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR) + class MessageTreeState(SQLModel, table=True): __tablename__ = "message_tree_state" diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7c7dd7b6..7446ec07 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -8,98 +8,39 @@ from uuid import UUID, uuid4 import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User +from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User from oasst_backend.models.payload_column_type import PayloadContainer +from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id +from oasst_backend.user_repository import UserRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats +from oasst_shared.schemas.protocol import SystemStats from sqlalchemy import update from sqlmodel import Session, func from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND class PromptRepository: - def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]): + def __init__( + self, + db: Session, + api_client: ApiClient, + client_user: Optional[protocol_schema.User] = None, + user_repository: Optional[UserRepository] = None, + task_repository: Optional[TaskRepository] = None, + ): self.db = db self.api_client = api_client - self.user = self.lookup_user(user) + self.user_repository = user_repository or UserRepository(db, api_client) + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) self.user_id = self.user.id if self.user else None + self.task_repository = task_repository or TaskRepository( + db, api_client, client_user, user_repository=self.user_repository + ) self.journal = JournalWriter(db, api_client, self.user) - def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]: - if not client_user: - return None - user: User = ( - self.db.query(User) - .filter( - User.api_client_id == self.api_client.id, - User.username == client_user.id, - User.auth_method == client_user.auth_method, - ) - .first() - ) - if user is None: - # user is unknown, create new record - user = User( - username=client_user.id, - display_name=client_user.display_name, - api_client_id=self.api_client.id, - auth_method=client_user.auth_method, - ) - self.db.add(user) - self.db.commit() - self.db.refresh(user) - elif client_user.display_name and client_user.display_name != user.display_name: - # we found the user but the display name changed - user.display_name = client_user.display_name - self.db.add(user) - self.db.commit() - return user - - def validate_frontend_message_id(self, message_id: str) -> None: - # TODO: Should it be replaced with fastapi/pydantic validation? - if not isinstance(message_id, str): - raise OasstError( - f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID - ) - if not message_id: - raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) - - def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): - self.validate_frontend_message_id(frontend_message_id) - - # find task - task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() - if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) - if task.expired: - raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if task.done or task.ack is not None: - raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) - - task.frontend_message_id = frontend_message_id - task.ack = True - # ToDo: check race-condition, transaction - self.db.add(task) - self.db.commit() - - def acknowledge_task_failure(self, task_id): - # find task - task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() - if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) - if task.expired: - raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if task.done or task.ack is not None: - raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) - - task.ack = False - # ToDo: check race-condition, transaction - self.db.add(task) - self.db.commit() - def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message: - self.validate_frontend_message_id(frontend_message_id) + validate_frontend_message_id(frontend_message_id) message: Message = ( self.db.query(Message) .filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id) @@ -113,20 +54,48 @@ class PromptRepository: ) return message - def fetch_task_by_frontend_message_id(self, message_id: str) -> Task: - self.validate_frontend_message_id(message_id) - task = ( - self.db.query(Task) - .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id) - .one_or_none() + def insert_message( + self, + *, + message_id: UUID, + frontend_message_id: str, + parent_id: UUID, + message_tree_id: UUID, + task_id: UUID, + role: str, + payload: db_payload.MessagePayload, + payload_type: str = None, + depth: int = 0, + ) -> Message: + if payload_type is None: + if payload is None: + payload_type = "null" + else: + payload_type = type(payload).__name__ + + message = Message( + id=message_id, + parent_id=parent_id, + message_tree_id=message_tree_id, + task_id=task_id, + user_id=self.user_id, + role=role, + frontend_message_id=frontend_message_id, + api_client_id=self.api_client.id, + payload_type=payload_type, + payload=PayloadContainer(payload=payload), + depth=depth, ) - return task + self.db.add(message) + self.db.commit() + self.db.refresh(message) + return message def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message: - self.validate_frontend_message_id(frontend_message_id) - self.validate_frontend_message_id(user_frontend_message_id) + validate_frontend_message_id(frontend_message_id) + validate_frontend_message_id(user_frontend_message_id) - task = self.fetch_task_by_frontend_message_id(frontend_message_id) + task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id) if task is None: raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND) @@ -174,7 +143,7 @@ class PromptRepository: def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction: message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True) - task = self.fetch_task_by_frontend_message_id(rating.message_id) + task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id) task_payload: db_payload.RateSummaryPayload = task.payload.payload if type(task_payload) != db_payload.RateSummaryPayload: raise OasstError( @@ -201,7 +170,7 @@ class PromptRepository: def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction: # fetch task - task = self.fetch_task_by_frontend_message_id(ranking.message_id) + task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id) if not task.collective: task.done = True self.db.add(task) @@ -255,142 +224,6 @@ class PromptRepository: OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH, ) - def store_task( - self, - task: protocol_schema.Task, - message_tree_id: UUID = None, - parent_message_id: UUID = None, - collective: bool = False, - ) -> Task: - payload: db_payload.TaskPayload - match type(task): - case protocol_schema.SummarizeStoryTask: - payload = db_payload.SummarizationStoryPayload(story=task.story) - - case protocol_schema.RateSummaryTask: - payload = db_payload.RateSummaryPayload( - full_text=task.full_text, summary=task.summary, scale=task.scale - ) - - case protocol_schema.InitialPromptTask: - payload = db_payload.InitialPromptPayload(hint=task.hint) - - case protocol_schema.PrompterReplyTask: - payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint) - - case protocol_schema.AssistantReplyTask: - payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation) - - case protocol_schema.RankInitialPromptsTask: - payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts) - - case protocol_schema.RankPrompterRepliesTask: - payload = db_payload.RankPrompterRepliesPayload( - type=task.type, conversation=task.conversation, replies=task.replies - ) - - case protocol_schema.RankAssistantRepliesTask: - payload = db_payload.RankAssistantRepliesPayload( - type=task.type, conversation=task.conversation, replies=task.replies - ) - - case protocol_schema.LabelInitialPromptTask: - payload = db_payload.LabelInitialPromptPayload( - type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels - ) - - case protocol_schema.LabelPrompterReplyTask: - payload = db_payload.LabelPrompterReplyPayload( - type=task.type, - message_id=task.message_id, - conversation=task.conversation, - reply=task.reply, - valid_labels=task.valid_labels, - ) - - case protocol_schema.LabelAssistantReplyTask: - payload = db_payload.LabelAssistantReplyPayload( - type=task.type, - message_id=task.message_id, - conversation=task.conversation, - reply=task.reply, - valid_labels=task.valid_labels, - ) - - case _: - raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) - - task = self.insert_task( - payload=payload, - id=task.id, - message_tree_id=message_tree_id, - parent_message_id=parent_message_id, - collective=collective, - ) - assert task.id == task.id - return task - - def insert_task( - self, - payload: db_payload.TaskPayload, - id: UUID = None, - message_tree_id: UUID = None, - parent_message_id: UUID = None, - collective: bool = False, - ) -> Task: - c = PayloadContainer(payload=payload) - task = Task( - id=id, - user_id=self.user_id, - payload_type=type(payload).__name__, - payload=c, - api_client_id=self.api_client.id, - message_tree_id=message_tree_id, - parent_message_id=parent_message_id, - collective=collective, - ) - self.db.add(task) - self.db.commit() - self.db.refresh(task) - return task - - def insert_message( - self, - *, - message_id: UUID, - frontend_message_id: str, - parent_id: UUID, - message_tree_id: UUID, - task_id: UUID, - role: str, - payload: db_payload.MessagePayload, - payload_type: str = None, - depth: int = 0, - ) -> Message: - if payload_type is None: - if payload is None: - payload_type = "null" - else: - payload_type = type(payload).__name__ - - message = Message( - id=message_id, - parent_id=parent_id, - message_tree_id=message_tree_id, - task_id=task_id, - user_id=self.user_id, - role=role, - frontend_message_id=frontend_message_id, - api_client_id=self.api_client.id, - payload_type=payload_type, - payload=PayloadContainer(payload=payload), - depth=depth, - ) - self.db.add(message) - self.db.commit() - self.db.refresh(message) - return message - def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) @@ -515,28 +348,6 @@ class PromptRepository: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) return message - def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): - """ - Mark task as done. No further messages will be accepted for this task. - """ - self.validate_frontend_message_id(frontend_message_id) - task = self.fetch_task_by_frontend_message_id(frontend_message_id) - - if not task: - raise OasstError( - f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND - ) - if task.expired: - raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED) - if not allow_personal_tasks and not task.collective: - raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE) - if task.done: - raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE) - - task.done = True - self.db.add(task) - self.db.commit() - @staticmethod def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]: """ @@ -728,24 +539,3 @@ class PromptRepository: deleted=result.get(True, 0), message_trees=result.get(None, 0), ) - - def get_user_leaderboard(self, role: str) -> LeaderboardStats: - """ - Get leaderboard stats for Messages created, - separate leaderboard for prompts & assistants - - """ - query = ( - self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id)) - .join(User, User.id == Message.user_id, isouter=True) - .filter(Message.deleted is not True, Message.role == role) - .group_by(Message.user_id, User.username, User.display_name) - .order_by(func.count(Message.user_id).desc()) - ) - - result = [ - {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]} - for i, j in enumerate(query.all(), start=1) - ] - - return LeaderboardStats(leaderboard=result) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py new file mode 100644 index 00000000..15484d66 --- /dev/null +++ b/backend/oasst_backend/task_repository.py @@ -0,0 +1,199 @@ +from typing import Optional +from uuid import UUID + +import oasst_backend.models.db_payload as db_payload +from oasst_backend.models import ApiClient, Task +from oasst_backend.models.payload_column_type import PayloadContainer +from oasst_backend.user_repository import UserRepository +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from oasst_shared.schemas import protocol as protocol_schema +from sqlmodel import Session +from starlette.status import HTTP_404_NOT_FOUND + + +def validate_frontend_message_id(message_id: str) -> None: + # TODO: Should it be replaced with fastapi/pydantic validation? + if not isinstance(message_id, str): + raise OasstError( + f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID + ) + if not message_id: + raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID) + + +class TaskRepository: + def __init__( + self, + db: Session, + api_client: ApiClient, + client_user: Optional[protocol_schema.User], + user_repository: UserRepository, + ): + self.db = db + self.api_client = api_client + self.user_repository = user_repository + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) + self.user_id = self.user.id if self.user else None + + def store_task( + self, + task: protocol_schema.Task, + message_tree_id: UUID = None, + parent_message_id: UUID = None, + collective: bool = False, + ) -> Task: + payload: db_payload.TaskPayload + match type(task): + case protocol_schema.SummarizeStoryTask: + payload = db_payload.SummarizationStoryPayload(story=task.story) + + case protocol_schema.RateSummaryTask: + payload = db_payload.RateSummaryPayload( + full_text=task.full_text, summary=task.summary, scale=task.scale + ) + + case protocol_schema.InitialPromptTask: + payload = db_payload.InitialPromptPayload(hint=task.hint) + + case protocol_schema.PrompterReplyTask: + payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint) + + case protocol_schema.AssistantReplyTask: + payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation) + + case protocol_schema.RankInitialPromptsTask: + payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts) + + case protocol_schema.RankPrompterRepliesTask: + payload = db_payload.RankPrompterRepliesPayload( + type=task.type, conversation=task.conversation, replies=task.replies + ) + + case protocol_schema.RankAssistantRepliesTask: + payload = db_payload.RankAssistantRepliesPayload( + type=task.type, conversation=task.conversation, replies=task.replies + ) + + case protocol_schema.LabelInitialPromptTask: + payload = db_payload.LabelInitialPromptPayload( + type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels + ) + + case protocol_schema.LabelPrompterReplyTask: + payload = db_payload.LabelPrompterReplyPayload( + type=task.type, + message_id=task.message_id, + conversation=task.conversation, + reply=task.reply, + valid_labels=task.valid_labels, + ) + + case protocol_schema.LabelAssistantReplyTask: + payload = db_payload.LabelAssistantReplyPayload( + type=task.type, + message_id=task.message_id, + conversation=task.conversation, + reply=task.reply, + valid_labels=task.valid_labels, + ) + + case _: + raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE) + + task = self.insert_task( + payload=payload, + id=task.id, + message_tree_id=message_tree_id, + parent_message_id=parent_message_id, + collective=collective, + ) + assert task.id == task.id + return task + + def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): + validate_frontend_message_id(frontend_message_id) + + # find task + task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() + if task is None: + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) + if task.expired: + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) + if task.done or task.ack is not None: + raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) + + task.frontend_message_id = frontend_message_id + task.ack = True + # ToDo: check race-condition, transaction + self.db.add(task) + self.db.commit() + + def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): + """ + Mark task as done. No further messages will be accepted for this task. + """ + validate_frontend_message_id(frontend_message_id) + task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id) + + if not task: + raise OasstError( + f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND + ) + if task.expired: + raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED) + if not allow_personal_tasks and not task.collective: + raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE) + if task.done: + raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE) + + task.done = True + self.db.add(task) + self.db.commit() + + def acknowledge_task_failure(self, task_id): + # find task + task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() + if task is None: + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) + if task.expired: + raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) + if task.done or task.ack is not None: + raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) + + task.ack = False + # ToDo: check race-condition, transaction + self.db.add(task) + self.db.commit() + + def insert_task( + self, + payload: db_payload.TaskPayload, + id: UUID = None, + message_tree_id: UUID = None, + parent_message_id: UUID = None, + collective: bool = False, + ) -> Task: + c = PayloadContainer(payload=payload) + task = Task( + id=id, + user_id=self.user_id, + payload_type=type(payload).__name__, + payload=c, + api_client_id=self.api_client.id, + message_tree_id=message_tree_id, + parent_message_id=parent_message_id, + collective=collective, + ) + self.db.add(task) + self.db.commit() + self.db.refresh(task) + return task + + def fetch_task_by_frontend_message_id(self, message_id: str) -> Task: + validate_frontend_message_id(message_id) + task = ( + self.db.query(Task) + .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id) + .one_or_none() + ) + return task diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py new file mode 100644 index 00000000..b5508899 --- /dev/null +++ b/backend/oasst_backend/user_repository.py @@ -0,0 +1,64 @@ +from typing import Optional + +from oasst_backend.models import ApiClient, Message, User +from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import LeaderboardStats +from sqlmodel import Session, func + + +class UserRepository: + def __init__(self, db: Session, api_client: ApiClient): + self.db = db + self.api_client = api_client + + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + if not client_user: + return None + user: User = ( + self.db.query(User) + .filter( + User.api_client_id == self.api_client.id, + User.username == client_user.id, + User.auth_method == client_user.auth_method, + ) + .first() + ) + if user is None: + if create_missing: + # user is unknown, create new record + user = User( + username=client_user.id, + display_name=client_user.display_name, + api_client_id=self.api_client.id, + auth_method=client_user.auth_method, + ) + self.db.add(user) + self.db.commit() + self.db.refresh(user) + elif client_user.display_name and client_user.display_name != user.display_name: + # we found the user but the display name changed + user.display_name = client_user.display_name + self.db.add(user) + self.db.commit() + return user + + def get_user_leaderboard(self, role: str) -> LeaderboardStats: + """ + Get leaderboard stats for Messages created, + separate leaderboard for prompts & assistants + + """ + query = ( + self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id)) + .join(User, User.id == Message.user_id, isouter=True) + .filter(Message.deleted is not True, Message.role == role) + .group_by(Message.user_id, User.username, User.display_name) + .order_by(func.count(Message.user_id).desc()) + ) + + result = [ + {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]} + for i, j in enumerate(query.all(), start=1) + ] + + return LeaderboardStats(leaderboard=result)