From 801ad553b846c96843d83c0fcdf3fd4ec5af9753 Mon Sep 17 00:00:00 2001 From: notmd Date: Wed, 18 Jan 2023 15:19:00 +0700 Subject: [PATCH 001/111] Allow to filter `user` by `display_name` --- backend/oasst_backend/user_repository.py | 4 +- website/package-lock.json | 45 +++++++ website/package.json | 1 + website/src/components/DataTable.tsx | 158 +++++++++++++++++++++++ website/src/components/UserTable.tsx | 133 +++++++++++++++++++ website/src/components/UsersCell.tsx | 137 -------------------- website/src/lib/oasst_api_client.ts | 10 ++ website/src/pages/admin/index.tsx | 5 +- website/src/pages/api/admin/users.ts | 13 +- 9 files changed, 361 insertions(+), 145 deletions(-) create mode 100644 website/src/components/DataTable.tsx create mode 100644 website/src/components/UserTable.tsx delete mode 100644 website/src/components/UsersCell.tsx diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 578dc5f1..c244d67f 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -161,10 +161,10 @@ class UserRepository: users = users.order_by(User.display_name) if gt: - users = users.filter(User.display_name > gt) + users = users.filter(User.id > gt) if lt: - users = users.filter(User.display_name < lt) + users = users.filter(User.id < lt).order_by(None).order_by(User.id.desc()) if limit is not None: users = users.limit(limit) diff --git a/website/package-lock.json b/website/package-lock.json index 1fa3d14d..29cd0326 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -21,6 +21,7 @@ "@next/font": "^13.1.0", "@prisma/client": "^4.7.1", "@tailwindcss/forms": "^0.5.3", + "@tanstack/react-table": "^8.7.6", "autoprefixer": "^10.4.13", "axios": "^1.2.1", "boolean": "^3.2.0", @@ -12294,6 +12295,37 @@ "tailwindcss": ">=3.0.0 || >= 3.0.0-alpha.1" } }, + "node_modules/@tanstack/react-table": { + "version": "8.7.6", + "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.7.6.tgz", + "integrity": "sha512-/QijmMFeP7wDLBnr0MQ/5MlbXePbIL/1nOtkxBC9zvmBu4gDKJEDBqipUyM7Wc/iBpSd0IFyqBlvZvTPD9FYDA==", + "dependencies": { + "@tanstack/table-core": "8.7.6" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "react": ">=16", + "react-dom": ">=16" + } + }, + "node_modules/@tanstack/table-core": { + "version": "8.7.6", + "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.7.6.tgz", + "integrity": "sha512-sqiNTMzB6cpyL8DFH6/VqW48SwiflLqxQqYpo2wNock7rdVGvlm0BLNI8vZUJbr1+fmmWmHwBvi5OMgZw8n1DA==", + "engines": { + "node": ">=12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, "node_modules/@testing-library/dom": { "version": "8.19.1", "resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-8.19.1.tgz", @@ -46565,6 +46597,19 @@ "mini-svg-data-uri": "^1.2.3" } }, + "@tanstack/react-table": { + "version": "8.7.6", + "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.7.6.tgz", + "integrity": "sha512-/QijmMFeP7wDLBnr0MQ/5MlbXePbIL/1nOtkxBC9zvmBu4gDKJEDBqipUyM7Wc/iBpSd0IFyqBlvZvTPD9FYDA==", + "requires": { + "@tanstack/table-core": "8.7.6" + } + }, + "@tanstack/table-core": { + "version": "8.7.6", + "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.7.6.tgz", + "integrity": "sha512-sqiNTMzB6cpyL8DFH6/VqW48SwiflLqxQqYpo2wNock7rdVGvlm0BLNI8vZUJbr1+fmmWmHwBvi5OMgZw8n1DA==" + }, "@testing-library/dom": { "version": "8.19.1", "resolved": "https://registry.npmjs.org/@testing-library/dom/-/dom-8.19.1.tgz", diff --git a/website/package.json b/website/package.json index 580d0be3..6dcbb26a 100644 --- a/website/package.json +++ b/website/package.json @@ -38,6 +38,7 @@ "@next/font": "^13.1.0", "@prisma/client": "^4.7.1", "@tailwindcss/forms": "^0.5.3", + "@tanstack/react-table": "^8.7.6", "autoprefixer": "^10.4.13", "axios": "^1.2.1", "boolean": "^3.2.0", diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx new file mode 100644 index 00000000..eafca6d1 --- /dev/null +++ b/website/src/components/DataTable.tsx @@ -0,0 +1,158 @@ +import { + Box, + Button, + Card, + CardBody, + Flex, + FormControl, + FormLabel, + Input, + Popover, + PopoverArrow, + PopoverBody, + PopoverCloseButton, + PopoverContent, + PopoverTrigger, + Spacer, + Table, + TableCaption, + TableContainer, + Tbody, + Td, + Th, + Thead, + Tr, + useDisclosure, +} from "@chakra-ui/react"; +import { ColumnDef, flexRender, getCoreRowModel, useReactTable } from "@tanstack/react-table"; +import { ChangeEvent, ReactNode } from "react"; +import { FaFilter } from "react-icons/fa"; +import { useDebouncedCallback } from "use-debounce"; + +export type DataTableColumnDef = ColumnDef & { + filterable?: boolean; +}; + +// TODO: stricter type +export type FilterItem = { + id: string; + value: string; +}; + +export type DataTableProps = { + data: T[]; + columns: DataTableColumnDef[]; + caption?: string; + filterValues?: FilterItem[]; + onNextClick?: () => void; + onPreviousClick?: () => void; + onFilterChange?: (items: FilterItem[]) => void; +}; + +export const DataTable = ({ + data, + columns, + caption, + filterValues = [], + onNextClick, + onPreviousClick, + onFilterChange, +}: DataTableProps) => { + const { getHeaderGroups, getRowModel } = useReactTable({ + data, + columns, + getCoreRowModel: getCoreRowModel(), + }); + + const handleFilterChange = (value: FilterItem) => { + const idx = filterValues.findIndex((oldValue) => oldValue.id === value.id); + let newValues: FilterItem[] = []; + if (idx === -1) { + newValues = [...filterValues, value]; + } else { + newValues = filterValues.map((oldValue) => (oldValue.id === value.id ? value : oldValue)); + } + onFilterChange(newValues); + }; + return ( + + + + + + + + + + {caption} + + {getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + ))} + + ))} + + + {getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + ))} + + ))} + +
+ + {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} + {(header.column.columnDef as DataTableColumnDef).filterable && ( + value.id === header.id)?.value ?? ""} + onChange={(value) => handleFilterChange({ id: header.id, value })} + label={flexRender(header.column.columnDef.header, header.getContext())} + > + )} + +
{flexRender(cell.column.columnDef.cell, cell.getContext())}
+
+
+
+ ); +}; + +const FilterModal = ({ + label, + onChange, + value, +}: { + label: ReactNode; + onChange: (val: string) => void; + value: string; +}) => { + const { isOpen, onOpen, onClose } = useDisclosure(); + + const handleInputChange = useDebouncedCallback((e: ChangeEvent) => { + onChange(e.target.value); + }, 500); + + return ( + + + + + + + + + + {label} + + + + + + ); +}; diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx new file mode 100644 index 00000000..3b63255a --- /dev/null +++ b/website/src/components/UserTable.tsx @@ -0,0 +1,133 @@ +import { IconButton, useToast } from "@chakra-ui/react"; +import { createColumnHelper } from "@tanstack/react-table"; +import Link from "next/link"; +import { memo, useState } from "react"; +import { FaPen } from "react-icons/fa"; +import { get } from "src/lib/api"; +import type { User } from "src/types/Users"; +import useSWR from "swr"; + +import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable"; + +interface Pagination { + /** + * The user's `display_name` used for pagination. + */ + cursor: string; + + /** + * The pagination direction. + */ + direction: "forward" | "back"; +} + +const columnHelper = createColumnHelper(); + +const columns: DataTableColumnDef[] = [ + columnHelper.accessor("user_id", { + header: "ID", + }), + columnHelper.accessor("id", { + header: "Auth ID", + }), + columnHelper.accessor("auth_method", { + header: "Auth Method", + }), + { + ...columnHelper.accessor("display_name", { + header: "Name", + }), + filterable: true, + }, + columnHelper.accessor("role", { + header: "Role", + }), + columnHelper.accessor((user) => user.user_id, { + cell: ({ getValue }) => ( + } + > + ), + header: "Update", + }), +]; + +export const UserTable = memo(function UserTable() { + const toast = useToast(); + const [pagination, setPagination] = useState({ cursor: "", direction: "forward" }); + const [users, setUsers] = useState([]); + const [filterValues, setFilterValues] = useState([]); + // Fetch and save the users. + // This follows useSWR's recommendation for simple pagination: + // https://swr.vercel.app/docs/pagination#when-to-use-useswr + const display_name = filterValues.find((value) => value.id === "display_name")?.value ?? ""; + useSWR( + `/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}&display_name=${display_name}`, + get, + { + onSuccess: (data) => { + // When no more users can be found, trigger a toast to indicate why no + // changes have taken place. We have to maintain a non-empty set of + // users otherwise we can't paginate using a cursor (since we've lost the + // cursor). + if (data.length === 0) { + toast({ + title: "No more users", + status: "warning", + duration: 1000, + isClosable: true, + }); + return; + } + setUsers(data); + }, + } + ); + + const toPreviousPage = () => { + if (users.length >= 0) { + setPagination({ + cursor: users[0].user_id, + direction: "back", + }); + } else { + toast({ + title: "Can not paginate when no users are found", + status: "warning", + duration: 1000, + isClosable: true, + }); + } + }; + + const toNextPage = () => { + if (users.length >= 0) { + setPagination({ + cursor: users[users.length - 1].user_id, + direction: "forward", + }); + } else { + toast({ + title: "Can not paginate when no users are found", + status: "warning", + duration: 1000, + isClosable: true, + }); + } + }; + + return ( + + ); +}); diff --git a/website/src/components/UsersCell.tsx b/website/src/components/UsersCell.tsx deleted file mode 100644 index 99824090..00000000 --- a/website/src/components/UsersCell.tsx +++ /dev/null @@ -1,137 +0,0 @@ -import { - Button, - Flex, - Spacer, - Stack, - Table, - TableCaption, - TableContainer, - Tbody, - Td, - Th, - Thead, - Tr, - useToast, -} from "@chakra-ui/react"; -import Link from "next/link"; -import { useState } from "react"; -import { get } from "src/lib/api"; -import type { User } from "src/types/Users"; -import useSWR from "swr"; - -interface Pagination { - /** - * The user's `display_name` used for pagination. - */ - cursor: string; - - /** - * The pagination direction. - */ - direction: "forward" | "back"; -} - -/** - * Fetches users from the users api route and then presents them in a simple Chakra table. - */ -const UsersCell = () => { - const toast = useToast(); - const [pagination, setPagination] = useState({ cursor: "", direction: "forward" }); - const [users, setUsers] = useState([]); - - // Fetch and save the users. - // This follows useSWR's recommendation for simple pagination: - // https://swr.vercel.app/docs/pagination#when-to-use-useswr - useSWR(`/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}`, get, { - onSuccess: (data) => { - // When no more users can be found, trigger a toast to indicate why no - // changes have taken place. We have to maintain a non-empty set of - // users otherwise we can't paginate using a cursor (since we've lost the - // cursor). - if (data.length === 0) { - toast({ - title: "No more users", - status: "warning", - duration: 1000, - isClosable: true, - }); - return; - } - setUsers(data); - }, - }); - - const toPreviousPage = () => { - if (users.length >= 0) { - setPagination({ - cursor: users[0].display_name, - direction: "back", - }); - } else { - toast({ - title: "Can not paginate when no users are found", - status: "warning", - duration: 1000, - isClosable: true, - }); - } - }; - - const toNextPage = () => { - if (users.length >= 0) { - setPagination({ - cursor: users[users.length - 1].display_name, - direction: "forward", - }); - } else { - toast({ - title: "Can not paginate when no users are found", - status: "warning", - duration: 1000, - isClosable: true, - }); - } - }; - - // Present users in a naive table. - return ( - - - - - - - - - Users - - - - - - - - - - - - {users.map(({ id, user_id, auth_method, display_name, role }) => ( - - - - - - - - - ))} - -
IdAuth IdAuth MethodNameRoleUpdate
{user_id}{id}{auth_method}{display_name}{role} - Manage -
-
-
- ); -}; - -export default UsersCell; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index fb11adec..866b2907 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -187,6 +187,16 @@ export class OasstApiClient { return this.get(url); } + async fetch_user_by_display_name(name: string): Promise { + const params = new URLSearchParams({ + search_text: name, + }); + + const endpoint = `/api/v1/frontend_users/by_display_name`; + + return this.get(`${endpoint}?${params.toString()}`); + } + /** * Returns the `Message`s associated with `user_id` in the backend. */ diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx index 9cbea222..397230bd 100644 --- a/website/src/pages/admin/index.tsx +++ b/website/src/pages/admin/index.tsx @@ -3,7 +3,7 @@ import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; import { useEffect } from "react"; import { getAdminLayout } from "src/components/Layout"; -import UsersCell from "src/components/UsersCell"; +import { UserTable } from "src/components/UserTable"; /** * Provides the admin index page that will display a list of users and give @@ -27,7 +27,6 @@ const AdminIndex = () => { } router.push("/"); }, [router, session, status]); - return ( <> @@ -37,7 +36,7 @@ const AdminIndex = () => { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
{status === "loading" ? "loading..." : }
+
{status === "loading" ? "loading..." : }
); }; diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index e600650d..5cc41354 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,11 +1,12 @@ import { withRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; +import { BackendUser } from "src/types/Users"; /** * The number of users to fetch in a single request. Could later be a query parameter. */ -const PAGE_SIZE = 20; +const PAGE_SIZE = 1; /** * Returns a list of user results from the database when the requesting user is @@ -17,10 +18,16 @@ const PAGE_SIZE = 20; * direction. */ const handler = withRole("admin", async (req, res) => { - const { cursor, direction } = req.query; + const { cursor, direction, display_name = "" } = req.query; // First, get all the users according to the backend. - const all_users = await oasstApiClient.fetch_users(PAGE_SIZE, cursor as string, direction === "forward"); + let all_users: BackendUser[] = []; + + if (typeof display_name === "string" && display_name) { + all_users = await oasstApiClient.fetch_user_by_display_name(display_name); + } else { + all_users = await oasstApiClient.fetch_users(PAGE_SIZE, cursor as string, direction === "forward"); + } // Next, get all the users stored in the web's auth database to fetch their role. const local_user_ids = all_users.map(({ id }) => id); From 8eff6932d6867e85b53de314e744f46d3f36953c Mon Sep 17 00:00:00 2001 From: notmd Date: Wed, 18 Jan 2023 15:40:20 +0700 Subject: [PATCH 002/111] switch PAGE_SIZE back to 20 --- website/src/pages/api/admin/users.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index 5cc41354..52921213 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -6,7 +6,7 @@ import { BackendUser } from "src/types/Users"; /** * The number of users to fetch in a single request. Could later be a query parameter. */ -const PAGE_SIZE = 1; +const PAGE_SIZE = 20; /** * Returns a list of user results from the database when the requesting user is From 622a4768f63415ca2b5ff1b9efa31f7e6550e2d4 Mon Sep 17 00:00:00 2001 From: notmd Date: Wed, 18 Jan 2023 16:07:54 +0700 Subject: [PATCH 003/111] fix default column --- backend/oasst_backend/user_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index c244d67f..7c46a026 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -158,7 +158,7 @@ class UserRepository: if auth_method: users = users.filter(User.auth_method == auth_method) - users = users.order_by(User.display_name) + users = users.order_by(User.id) if gt: users = users.filter(User.id > gt) From 62a203fd8c3f7712c80b841932029712a35e3b4a Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 21 Jan 2023 03:31:35 +0000 Subject: [PATCH 004/111] [feature] move data formatting into dataset, instead of collator --- .../custom_datasets/__init__.py | 2 +- .../custom_datasets/dialogue_collator.py | 12 +----- .../custom_datasets/formatting.py | 5 +++ .../custom_datasets/prompt_dialogue.py | 7 ++-- .../custom_datasets/qa_datasets.py | 38 ++++++++++--------- .../custom_datasets/summarization.py | 7 +++- .../custom_datasets/toxic_conversation.py | 10 +++-- .../custom_datasets/translation.py | 3 +- .../tests/test_datasets.py | 13 +++---- 9 files changed, 51 insertions(+), 46 deletions(-) create mode 100644 model/supervised_finetuning/custom_datasets/formatting.py diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 2e1e4b30..558ec502 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -43,7 +43,7 @@ def get_one_dataset(conf, dataset_name): if dataset_name == "debate_sum": train, eval = train_val_dataset(train, val_split=0.2) else: - val_name = "validation" if dataset_name not in ["billsum"] else "test" + val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test" eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 719fa0d6..c96ed576 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -3,7 +3,6 @@ from typing import Optional, Union import numpy as np import torch -from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase @@ -23,15 +22,8 @@ class DialogueDataCollator: flatten_messages = [] label_masks = [] - for feature_one in features: - assert len(feature_one) % 2 == 0, "Number of messages must be even" - # TODO: we should push this to dataset __getitem__ - messages = [ - (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") - + x - + (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "") - for i, x in enumerate(feature_one) - ] + for messages in features: + messages = list(messages) # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation diff --git a/model/supervised_finetuning/custom_datasets/formatting.py b/model/supervised_finetuning/custom_datasets/formatting.py new file mode 100644 index 00000000..2f0adecd --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/formatting.py @@ -0,0 +1,5 @@ +QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} + + +def format_pair(pair): + return "{} {} {}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1] diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 4a1d83a3..1c823934 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -2,6 +2,7 @@ import json import os from urllib.request import urlopen +from custom_datasets.formatting import format_pair from torch.utils.data import Dataset @@ -49,8 +50,7 @@ class PromptGeneratedDataset(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) class InstructionTuning(Dataset): @@ -101,5 +101,4 @@ class InstructionTuning(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 7d9c7f48..47b1c247 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -7,14 +7,13 @@ import re from urllib.request import urlopen import numpy as np +from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from datasets import load_dataset from torch.utils.data import Dataset # @agoryuno contributed this re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") -QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} - def index_squad_v2(example): if len(example["answers"]["text"]): @@ -78,7 +77,7 @@ class QADataset(Dataset): def __getitem__(self, idx): data = self.dataset[idx] - return self.index_fn(data) + return format_pair(self.index_fn(data)) class WebGPT(Dataset): @@ -111,7 +110,7 @@ class WebGPT(Dataset): def __getitem__(self, index): question = self.index2question[index] answer = self.questions[question] - return [question, answer] + return format_pair((question, answer)) class SODA(Dataset): @@ -121,14 +120,14 @@ class SODA(Dataset): def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] - prefix = "{}{}. {}{}".format( - QA_SPECIAL_TOKENS["StartPrefix"], - data["narrative"], - "your name {}".format(play_as), - QA_SPECIAL_TOKENS["EndPrefix"], - ) question, answer = "", "" prefix, postfix = "", "" + dialogue_bg = "{}{} {}{}".format( + QA_SPECIAL_TOKENS["StartPrefix"], + data["narrative"], + "your are {}".format(play_as), + QA_SPECIAL_TOKENS["EndPrefix"], + ) previous_chat = [] for idx, convo in enumerate(data["dialogue"]): @@ -138,14 +137,20 @@ class SODA(Dataset): else: answer = convo postfix = data["speakers"][idx] + if len(question) and len(answer) and prefix != postfix and postfix == play_as: history = "".join( - ["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat] + [ + "{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) + for p in previous_chat + ] ) if len(history): history += "" - pairs.append((prefix + history + question, answer)) + prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"] + pairs.append((dialogue_bg + history + prompt, answer)) previous_chat.append((question, answer)) + return pairs def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: @@ -166,8 +171,8 @@ class SODA(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + # special token added during preprocess + return self.pairs[index] class SODADialogue(Dataset): @@ -218,7 +223,7 @@ class SODADialogue(Dataset): return len(self.pairs) def __getitem__(self, index): - return self.pairs[index] + return format_pair(self.pairs[index]) class JokeExplaination(Dataset): @@ -253,8 +258,7 @@ class JokeExplaination(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) # https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 2a097fe7..85d21a27 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -3,6 +3,7 @@ """ import random +from custom_datasets.formatting import format_pair from datasets import load_dataset from torch.utils.data import Dataset @@ -54,11 +55,12 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): - def __init__(self, dataset, cache_dir, split): + def __init__(self, dataset, cache_dir, split, max_words=512): self.name = dataset self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default + self.max_words = max_words def __len__(self): return len(self.dataset) @@ -72,4 +74,5 @@ class SummarizationDataset(Dataset): else: prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) - return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary) + context = "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[: self.max_words]), prompt]) + return format_pair((context, summary)) diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py index 815ac722..d3433106 100644 --- a/model/supervised_finetuning/custom_datasets/toxic_conversation.py +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -4,12 +4,13 @@ """ import random +from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from datasets import load_dataset from torch.utils.data import Dataset class ProsocialDialogueExplaination(Dataset): - name = "prosocial_explain" + name = "explain_prosocial" TEMPLATE = [ # 0 : reply or sentence of interest, 1 : reason of caution ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"), @@ -36,7 +37,7 @@ class ProsocialDialogueExplaination(Dataset): return len(self.pairs) def __getitem__(self, idx): - return self.pairs[idx] + return format_pair(self.pairs[idx]) class ProsocialDialogue(Dataset): @@ -58,11 +59,12 @@ class ProsocialDialogue(Dataset): dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: + prompt = QA_SPECIAL_TOKENS["Question"] + row["context"] + QA_SPECIAL_TOKENS["Answer"] for answer in row["rots"]: - self.pairs.append((self.PREFIX + row["context"], answer)) + self.pairs.append((self.PREFIX + prompt, answer)) def __len__(self): return len(self.pairs) def __getitem__(self, idx): - return self.pairs[idx] + return format_pair(self.pairs[idx]) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index 694d31ce..008de751 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -8,6 +8,7 @@ """ import random +from custom_datasets.formatting import format_pair from datasets import load_dataset from torch.utils.data import Dataset @@ -82,7 +83,7 @@ class TranslationPair(Dataset): return len(self.pairs) def __getitem__(self, index): - return self.pairs[index] + return format_pair(self.pairs[index]) class WMT2019(TranslationPair): diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 3b59f289..8d5ad08f 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -7,8 +7,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS - others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"] - translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] + others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"] + translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") for dataset_name in translation + others + summarize_base + qa_base: @@ -31,7 +31,6 @@ def test_collate_fn(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] - trains, evals = [], [] for dataset_name in others + qa_base + summarize_base: print(dataset_name) @@ -41,10 +40,10 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - # print(batch.keys()) - # print(tokenizer.decode(batch['input_ids'][0])) - # print('-----') - # print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]])) + print(batch.keys()) + print(tokenizer.decode(batch["input_ids"][0])) + print("-----") + print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]])) assert batch["targets"].shape[1] <= 512 dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: From 7eb5023e8222f0977ad0d56f45107342a9f3df14 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 21 Jan 2023 14:34:22 +0700 Subject: [PATCH 005/111] use cursor endpoint --- website/src/components/UserTable.tsx | 63 +++++++++++++++------------- website/src/lib/oasst_api_client.ts | 51 ++++++++++++++++------ website/src/pages/api/admin/users.ts | 26 ++++++------ 3 files changed, 86 insertions(+), 54 deletions(-) diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index 3b63255a..68285bfa 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -4,6 +4,7 @@ import Link from "next/link"; import { memo, useState } from "react"; import { FaPen } from "react-icons/fa"; import { get } from "src/lib/api"; +import { FetchUsersResponse } from "src/lib/oasst_api_client"; import type { User } from "src/types/Users"; import useSWR from "swr"; @@ -58,39 +59,43 @@ const columns: DataTableColumnDef[] = [ export const UserTable = memo(function UserTable() { const toast = useToast(); const [pagination, setPagination] = useState({ cursor: "", direction: "forward" }); - const [users, setUsers] = useState([]); + const [response, setResponse] = useState, "sort_key" | "order">>({ + items: [], + }); const [filterValues, setFilterValues] = useState([]); + const handleFilterValuesChange = (values: FilterItem[]) => { + setFilterValues(values); + setPagination((old) => ({ ...old, cursor: "" })); + }; // Fetch and save the users. // This follows useSWR's recommendation for simple pagination: // https://swr.vercel.app/docs/pagination#when-to-use-useswr const display_name = filterValues.find((value) => value.id === "display_name")?.value ?? ""; - useSWR( - `/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}&display_name=${display_name}`, - get, - { - onSuccess: (data) => { - // When no more users can be found, trigger a toast to indicate why no - // changes have taken place. We have to maintain a non-empty set of - // users otherwise we can't paginate using a cursor (since we've lost the - // cursor). - if (data.length === 0) { - toast({ - title: "No more users", - status: "warning", - duration: 1000, - isClosable: true, - }); - return; - } - setUsers(data); - }, - } - ); + useSWR< + FetchUsersResponse + >(`/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}&searchDisplayName=${display_name}&sortKey=display_name`, get, { + onSuccess: (data) => { + // When no more users can be found, trigger a toast to indicate why no + // changes have taken place. We have to maintain a non-empty set of + // users otherwise we can't paginate using a cursor (since we've lost the + // cursor). + if (data.items.length === 0) { + toast({ + title: "No more users", + status: "warning", + duration: 1000, + isClosable: true, + }); + return; + } + setResponse(data); + }, + }); const toPreviousPage = () => { - if (users.length >= 0) { + if (response.items.length >= 0) { setPagination({ - cursor: users[0].user_id, + cursor: response.prev, direction: "back", }); } else { @@ -104,9 +109,9 @@ export const UserTable = memo(function UserTable() { }; const toNextPage = () => { - if (users.length >= 0) { + if (response.items.length >= 0) { setPagination({ - cursor: users[users.length - 1].user_id, + cursor: response.next, direction: "forward", }); } else { @@ -121,13 +126,13 @@ export const UserTable = memo(function UserTable() { return ( ); }); diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 7db6e3c2..50adf267 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,7 +1,7 @@ import type { Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import type { AvailableTasks } from "src/types/Task"; -import type { BackendUser, BackendUserCore } from "src/types/Users"; +import type { BackendUser, BackendUserCore, User } from "src/types/Users"; export class OasstError { message: string; @@ -15,6 +15,22 @@ export class OasstError { } } +export type FetchUsersParams = { + limit: number; + cursor?: string; + direction: "forward" | "back"; + searchDisplayName?: string; + sortKey?: "username" | "display_name"; +}; + +export type FetchUsersResponse = { + items: T[]; + next?: string; + prev?: string; + sort_key: "username" | "display_name"; + order: "asc" | "desc"; +}; + export class OasstApiClient { oasstApiUrl: string; oasstApiKey: string; @@ -164,30 +180,39 @@ export class OasstApiClient { * forward. If false and `cursor` is not empty, pages backwards. * @returns {Promise} A Promise that returns an array of `BackendUser` objects. */ - async fetch_users(max_count: number, cursor: string, isForward: boolean): Promise { - const params = new URLSearchParams(); - params.append("max_count", max_count.toString()); + async fetch_users({ + direction, + limit, + cursor, + searchDisplayName, + sortKey = "display_name", + }: FetchUsersParams): Promise { + const params = new URLSearchParams({ + search_text: searchDisplayName, + sort_key: sortKey, + max_count: limit.toString(), + }); // The backend API uses different query parameters depending on the // pagination direction but they both take the same cursor value. // Depending on direction, pick the right query param. if (cursor !== "") { - params.append(isForward ? "gt" : "lt", cursor); + params.append(direction === "forward" ? "gt" : "lt", cursor); } - const BASE_URL = `/api/v1/frontend_users`; + const BASE_URL = `/api/v1/users/cursor`; const url = `${BASE_URL}/?${params.toString()}`; return this.get(url); } - async fetch_user_by_display_name(name: string): Promise { - const params = new URLSearchParams({ - search_text: name, - }); + // async fetch_user_by_display_name(name: string): Promise { + // const params = new URLSearchParams({ + // search_text: name, + // }); - const endpoint = `/api/v1/frontend_users/by_display_name`; + // const endpoint = `/api/v1/frontend_users/by_display_name`; - return this.get(`${endpoint}?${params.toString()}`); - } + // return this.get(`${endpoint}?${params.toString()}`); + // } /** * Returns the `Message`s associated with `user_id` in the backend. diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index 52921213..f43af305 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,12 +1,11 @@ import { withRole } from "src/lib/auth"; -import { oasstApiClient } from "src/lib/oasst_api_client"; +import { FetchUsersParams, oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; -import { BackendUser } from "src/types/Users"; /** * The number of users to fetch in a single request. Could later be a query parameter. */ -const PAGE_SIZE = 20; +const PAGE_SIZE = 2; /** * Returns a list of user results from the database when the requesting user is @@ -18,16 +17,16 @@ const PAGE_SIZE = 20; * direction. */ const handler = withRole("admin", async (req, res) => { - const { cursor, direction, display_name = "" } = req.query; + const { cursor, direction, searchDisplayName = "", sortKey = "username" } = req.query; // First, get all the users according to the backend. - let all_users: BackendUser[] = []; - - if (typeof display_name === "string" && display_name) { - all_users = await oasstApiClient.fetch_user_by_display_name(display_name); - } else { - all_users = await oasstApiClient.fetch_users(PAGE_SIZE, cursor as string, direction === "forward"); - } + const { items: all_users, ...rest } = await oasstApiClient.fetch_users({ + searchDisplayName: searchDisplayName as FetchUsersParams["searchDisplayName"], + direction: direction as FetchUsersParams["direction"], + limit: PAGE_SIZE, + cursor: cursor as FetchUsersParams["cursor"], + sortKey: sortKey === "username" || sortKey === "display_name" ? sortKey : undefined, + }); // Next, get all the users stored in the web's auth database to fetch their role. const local_user_ids = all_users.map(({ id }) => id); @@ -58,7 +57,10 @@ const handler = withRole("admin", async (req, res) => { }; }); - res.status(200).json(users); + res.status(200).json({ + items: users, + ...rest, + }); }); export default handler; From 77210ee6d41100a912498dbe898901f6d605ffa5 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 21 Jan 2023 14:37:05 +0700 Subject: [PATCH 006/111] remove debug code --- website/src/pages/api/admin/users.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index f43af305..57944cff 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -5,7 +5,7 @@ import prisma from "src/lib/prismadb"; /** * The number of users to fetch in a single request. Could later be a query parameter. */ -const PAGE_SIZE = 2; +const PAGE_SIZE = 20; /** * Returns a list of user results from the database when the requesting user is From 7274512c2f408355370d7b36779bd78690b6bbfb Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sat, 21 Jan 2023 21:36:37 +0900 Subject: [PATCH 007/111] Implementing core of setting user languages and fetching language specific tasks --- website/next-i18next.config.js | 2 +- website/package-lock.json | 45 +++++++++++++++++++ website/package.json | 2 + website/src/components/Header/Header.tsx | 2 + .../LanguageSelector/LanguageSelector.tsx | 40 +++++++++++++++++ .../src/components/LanguageSelector/index.tsx | 1 + website/src/lib/oasst_api_client.ts | 7 ++- website/src/lib/users.ts | 16 ++++++- website/src/pages/api/new_task/[task_type].ts | 5 ++- website/src/pages/api/update_task.ts | 13 +++++- 10 files changed, 125 insertions(+), 8 deletions(-) create mode 100644 website/src/components/LanguageSelector/LanguageSelector.tsx create mode 100644 website/src/components/LanguageSelector/index.tsx diff --git a/website/next-i18next.config.js b/website/next-i18next.config.js index 7c87a7a4..40c4b14d 100644 --- a/website/next-i18next.config.js +++ b/website/next-i18next.config.js @@ -1,6 +1,6 @@ module.exports = { i18n: { defaultLocale: "en", - locales: ["en"], + locales: ["de", "en", "fr"], }, }; diff --git a/website/package-lock.json b/website/package-lock.json index 06f3c98d..5c5dc795 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -21,6 +21,7 @@ "@next/font": "^13.1.0", "@prisma/client": "^4.7.1", "@tailwindcss/forms": "^0.5.3", + "accept-language-parser": "^1.5.0", "autoprefixer": "^10.4.13", "axios": "^1.2.1", "boolean": "^3.2.0", @@ -38,6 +39,7 @@ "npm": "^9.2.0", "postcss-focus-visible": "^7.1.0", "react": "18.2.0", + "react-cookies": "^0.1.1", "react-dom": "18.2.0", "react-feature-flags": "^1.0.0", "react-hook-form": "^7.42.1", @@ -13616,6 +13618,11 @@ "integrity": "sha512-j2afSsaIENvHZN2B8GOpF566vZ5WVk5opAiMTvWgaQT8DkbOqsTfvNAvHoRGU2zzP8cPoqys+xHTRDWW8L+/BA==", "dev": true }, + "node_modules/accept-language-parser": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/accept-language-parser/-/accept-language-parser-1.5.0.tgz", + "integrity": "sha512-QhyTbMLYo0BBGg1aWbeMG4ekWtds/31BrEU+DONOg/7ax23vxpL03Pb7/zBmha2v7vdD3AyzZVWBVGEZxKOXWw==" + }, "node_modules/accepts": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", @@ -32466,6 +32473,23 @@ "react": "^15.3.0 || ^16.0.0 || ^17.0.0 || ^18.0.0" } }, + "node_modules/react-cookies": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/react-cookies/-/react-cookies-0.1.1.tgz", + "integrity": "sha512-PP75kJ4vtoHuuTdq0TAD3RmlAv7vuDQh9fkC4oDlhntgs9vX1DmREomO0Y1mcQKR9nMZ6/zxoflaMJ3MAmF5KQ==", + "dependencies": { + "cookie": "^0.3.1", + "object-assign": "^4.1.1" + } + }, + "node_modules/react-cookies/node_modules/cookie": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz", + "integrity": "sha512-+IJOX0OqlHCszo2mBUq+SrEbCj6w7Kpffqx60zYbPTFaO4+yYgRjHwcZNpWvaTylDHaV7PPmBHzSecZiMhtPgw==", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/react-docgen": { "version": "5.4.3", "resolved": "https://registry.npmjs.org/react-docgen/-/react-docgen-5.4.3.tgz", @@ -47817,6 +47841,11 @@ "integrity": "sha512-j2afSsaIENvHZN2B8GOpF566vZ5WVk5opAiMTvWgaQT8DkbOqsTfvNAvHoRGU2zzP8cPoqys+xHTRDWW8L+/BA==", "dev": true }, + "accept-language-parser": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/accept-language-parser/-/accept-language-parser-1.5.0.tgz", + "integrity": "sha512-QhyTbMLYo0BBGg1aWbeMG4ekWtds/31BrEU+DONOg/7ax23vxpL03Pb7/zBmha2v7vdD3AyzZVWBVGEZxKOXWw==" + }, "accepts": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", @@ -61962,6 +61991,22 @@ "@babel/runtime": "^7.12.13" } }, + "react-cookies": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/react-cookies/-/react-cookies-0.1.1.tgz", + "integrity": "sha512-PP75kJ4vtoHuuTdq0TAD3RmlAv7vuDQh9fkC4oDlhntgs9vX1DmREomO0Y1mcQKR9nMZ6/zxoflaMJ3MAmF5KQ==", + "requires": { + "cookie": "^0.3.1", + "object-assign": "^4.1.1" + }, + "dependencies": { + "cookie": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz", + "integrity": "sha512-+IJOX0OqlHCszo2mBUq+SrEbCj6w7Kpffqx60zYbPTFaO4+yYgRjHwcZNpWvaTylDHaV7PPmBHzSecZiMhtPgw==" + } + } + }, "react-docgen": { "version": "5.4.3", "resolved": "https://registry.npmjs.org/react-docgen/-/react-docgen-5.4.3.tgz", diff --git a/website/package.json b/website/package.json index 4ae762e4..8866a9e2 100644 --- a/website/package.json +++ b/website/package.json @@ -38,6 +38,7 @@ "@next/font": "^13.1.0", "@prisma/client": "^4.7.1", "@tailwindcss/forms": "^0.5.3", + "accept-language-parser": "^1.5.0", "autoprefixer": "^10.4.13", "axios": "^1.2.1", "boolean": "^3.2.0", @@ -55,6 +56,7 @@ "npm": "^9.2.0", "postcss-focus-visible": "^7.1.0", "react": "18.2.0", + "react-cookies": "^0.1.1", "react-dom": "18.2.0", "react-feature-flags": "^1.0.0", "react-hook-form": "^7.42.1", diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx index a1b36123..64614578 100644 --- a/website/src/components/Header/Header.tsx +++ b/website/src/components/Header/Header.tsx @@ -5,6 +5,7 @@ import { useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import { Flags } from "react-feature-flags"; import { FaUser } from "react-icons/fa"; +import { LanguageSelector } from "src/components/LanguageSelector"; import { UserMenu } from "./UserMenu"; @@ -45,6 +46,7 @@ export function Header() { FlagTest + diff --git a/website/src/components/LanguageSelector/LanguageSelector.tsx b/website/src/components/LanguageSelector/LanguageSelector.tsx new file mode 100644 index 00000000..e611bf0f --- /dev/null +++ b/website/src/components/LanguageSelector/LanguageSelector.tsx @@ -0,0 +1,40 @@ +import { Select } from "@chakra-ui/react"; +import { useRouter } from "next/router"; +import { useTranslation } from "next-i18next"; +import { useCallback, useMemo, useState } from "react"; +import cookie from "react-cookies"; + +const LanguageSelector = () => { + const router = useRouter(); + const { i18n } = useTranslation(); + + const { language: currentLanguage } = i18n; + const languageNames = useMemo(() => { + return new Intl.DisplayNames([currentLanguage], { + type: "language", + }); + }, [currentLanguage]); + + const languageChanged = useCallback( + async (option) => { + const locale = option.target.value; + cookie.save("NEXT_LOCALE", locale, { path: "/" }); + const path = router.asPath; + return router.push(path, path, { locale }); + }, + [router] + ); + + const locales = router.locales; + return ( + + ); +}; + +export { LanguageSelector }; diff --git a/website/src/components/LanguageSelector/index.tsx b/website/src/components/LanguageSelector/index.tsx new file mode 100644 index 00000000..feb9f322 --- /dev/null +++ b/website/src/components/LanguageSelector/index.tsx @@ -0,0 +1 @@ +export * from "./LanguageSelector"; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 1e74b020..36e3ae33 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -108,10 +108,11 @@ export class OasstApiClient { // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. // This is a raw Json type, so we can't use it to strongly type the task. - async fetchTask(taskType: string, user: BackendUserCore): Promise { + async fetchTask(taskType: string, user: BackendUserCore, lang: string): Promise { return this.post("/api/v1/tasks/", { type: taskType, user, + lang, }); } @@ -136,7 +137,8 @@ export class OasstApiClient { messageId: string, userMessageId: string, content: object, - user: BackendUserCore + user: BackendUserCore, + lang: string ): Promise { return this.post("/api/v1/tasks/interaction", { type: updateType, @@ -144,6 +146,7 @@ export class OasstApiClient { task_id: taskId, message_id: messageId, user_message_id: userMessageId, + lang, ...content, }); } diff --git a/website/src/lib/users.ts b/website/src/lib/users.ts index 2aa8c708..3dbe5a08 100644 --- a/website/src/lib/users.ts +++ b/website/src/lib/users.ts @@ -1,6 +1,20 @@ +import parser from "accept-language-parser"; +import type { NextApiRequest } from "next"; import prisma from "src/lib/prismadb"; import type { BackendUserCore } from "src/types/Users"; +const getUserLanguage = (req: NextApiRequest) => { + const cookieLanguage = req.cookies["NEXT_LOCALE"]; + if (cookieLanguage) { + return cookieLanguage; + } + const headerLanguages = parser.parse(req.headers["accept-language"]); + if (headerLanguages.length > 0) { + return headerLanguages[0].code; + } + return "en"; +}; + /** * Returns a `BackendUserCore` that can be used for interacting with the Backend service. * @@ -35,4 +49,4 @@ const getBackendUserCore = async (id: string) => { } as BackendUserCore; }; -export { getBackendUserCore }; +export { getBackendUserCore, getUserLanguage }; diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index c8255b18..360b8faa 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,7 +1,7 @@ import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; -import { getBackendUserCore } from "src/lib/users"; +import { getBackendUserCore, getUserLanguage } from "src/lib/users"; /** * Returns a new task created from the Task Backend. We do a few things here: @@ -14,11 +14,12 @@ import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { // Fetch the new task. const { task_type } = req.query; + const userLanguage = getUserLanguage(req); const user = await getBackendUserCore(token.sub); let task; try { - task = await oasstApiClient.fetchTask(task_type as string, user); + task = await oasstApiClient.fetchTask(task_type as string, user, userLanguage); } catch (err) { console.error(err); res.status(500).json(err); diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index c547503a..6f08d640 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -2,7 +2,7 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; -import { getBackendUserCore } from "src/lib/users"; +import { getBackendUserCore, getUserLanguage } from "src/lib/users"; /** * Stores the task interaction with the Task Backend and then returns the next task generated. @@ -41,9 +41,18 @@ const handler = withoutRole("banned", async (req, res, token) => { }); const user = await getBackendUserCore(token.sub); + const userLanguage = getUserLanguage(req); let newTask; try { - newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, user); + newTask = await oasstApiClient.interactTask( + update_type, + taskId, + frontendId, + interaction.id, + content, + user, + userLanguage + ); } catch (err) { console.error(JSON.stringify(err)); return res.status(500).json(err); From 27e1e549c42e830f6b560a52c3b595eddf7cc6d3 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 21 Jan 2023 20:10:19 +0700 Subject: [PATCH 008/111] fix query in backward direction --- backend/oasst_backend/user_repository.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 1e9ac78f..e960d944 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -200,6 +200,7 @@ class UserRepository: search_text: Optional[str] = None, limit: Optional[int] = 100, ) -> list[User]: + if not self.api_client.trusted: if not api_client_id: # Let unprivileged api clients query their own users without api_client_id being set @@ -226,11 +227,15 @@ class UserRepository: if lte_display_name is not None: if lt_id: - qry = qry.filter( - or_( - User.display_name < lte_display_name, - and_(User.display_name == lte_display_name, User.id < lt_id), + qry = ( + qry.filter( + or_( + User.display_name < lte_display_name, + and_(User.display_name == lte_display_name, User.id < lt_id), + ) ) + .order_by(None) + .order_by(User.display_name.desc(), User.id.desc()) ) else: qry = qry.filter(User.display_name <= lte_display_name) @@ -252,4 +257,9 @@ class UserRepository: if limit is not None: qry = qry.limit(limit) - return qry.all() + users = qry.all() + + if lte_display_name and lt_id: + users.reverse() + + return users From c1dd188cbea8e88a945b8dc190c5e32dc1872c90 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 21 Jan 2023 23:12:19 +0700 Subject: [PATCH 009/111] use pre and next cursor check from the server --- website/src/components/DataTable.tsx | 94 +++++++++++++++------------- website/src/components/UserTable.tsx | 94 ++++++++++------------------ website/src/pages/api/admin/users.ts | 2 +- 3 files changed, 84 insertions(+), 106 deletions(-) diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx index eafca6d1..1784650a 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable.tsx @@ -1,8 +1,6 @@ import { Box, Button, - Card, - CardBody, Flex, FormControl, FormLabel, @@ -47,6 +45,8 @@ export type DataTableProps = { onNextClick?: () => void; onPreviousClick?: () => void; onFilterChange?: (items: FilterItem[]) => void; + disableNext?: boolean; + disablePrevious?: boolean; }; export const DataTable = ({ @@ -57,6 +57,8 @@ export const DataTable = ({ onNextClick, onPreviousClick, onFilterChange, + disableNext, + disablePrevious, }: DataTableProps) => { const { getHeaderGroups, getRowModel } = useReactTable({ data, @@ -75,49 +77,51 @@ export const DataTable = ({ onFilterChange(newValues); }; return ( - - - - - - - - - - {caption} - - {getHeaderGroups().map((headerGroup) => ( - - {headerGroup.headers.map((header) => ( - - ))} - - ))} - - - {getRowModel().rows.map((row) => ( - - {row.getVisibleCells().map((cell) => ( - - ))} - - ))} - -
- - {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} - {(header.column.columnDef as DataTableColumnDef).filterable && ( - value.id === header.id)?.value ?? ""} - onChange={(value) => handleFilterChange({ id: header.id, value })} - label={flexRender(header.column.columnDef.header, header.getContext())} - > - )} - -
{flexRender(cell.column.columnDef.cell, cell.getContext())}
-
-
-
+ <> + + + + + + + + {caption} + + {getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + ))} + + ))} + + + {getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + ))} + + ))} + +
+ + {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} + {(header.column.columnDef as DataTableColumnDef).filterable && ( + value.id === header.id)?.value ?? ""} + onChange={(value) => handleFilterChange({ id: header.id, value })} + label={flexRender(header.column.columnDef.header, header.getContext())} + > + )} + +
{flexRender(cell.column.columnDef.cell, cell.getContext())}
+
+ ); }; diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index 68285bfa..1f4ccfca 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -1,4 +1,4 @@ -import { IconButton, useToast } from "@chakra-ui/react"; +import { Card, CardBody, IconButton } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; import Link from "next/link"; import { memo, useState } from "react"; @@ -57,11 +57,7 @@ const columns: DataTableColumnDef[] = [ ]; export const UserTable = memo(function UserTable() { - const toast = useToast(); const [pagination, setPagination] = useState({ cursor: "", direction: "forward" }); - const [response, setResponse] = useState, "sort_key" | "order">>({ - items: [], - }); const [filterValues, setFilterValues] = useState([]); const handleFilterValuesChange = (values: FilterItem[]) => { setFilterValues(values); @@ -71,68 +67,46 @@ export const UserTable = memo(function UserTable() { // This follows useSWR's recommendation for simple pagination: // https://swr.vercel.app/docs/pagination#when-to-use-useswr const display_name = filterValues.find((value) => value.id === "display_name")?.value ?? ""; - useSWR< - FetchUsersResponse - >(`/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}&searchDisplayName=${display_name}&sortKey=display_name`, get, { - onSuccess: (data) => { - // When no more users can be found, trigger a toast to indicate why no - // changes have taken place. We have to maintain a non-empty set of - // users otherwise we can't paginate using a cursor (since we've lost the - // cursor). - if (data.items.length === 0) { - toast({ - title: "No more users", - status: "warning", - duration: 1000, - isClosable: true, - }); - return; - } - setResponse(data); - }, - }); + const { data, error } = useSWR>( + `/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}&searchDisplayName=${display_name}&sortKey=display_name`, + get, + { + keepPreviousData: true, + } + ); const toPreviousPage = () => { - if (response.items.length >= 0) { - setPagination({ - cursor: response.prev, - direction: "back", - }); - } else { - toast({ - title: "Can not paginate when no users are found", - status: "warning", - duration: 1000, - isClosable: true, - }); - } + setPagination({ + cursor: data.prev, + direction: "back", + }); }; const toNextPage = () => { - if (response.items.length >= 0) { - setPagination({ - cursor: response.next, - direction: "forward", - }); - } else { - toast({ - title: "Can not paginate when no users are found", - status: "warning", - duration: 1000, - isClosable: true, - }); - } + setPagination({ + cursor: data.next, + direction: "forward", + }); }; return ( - + + + {data && ( + + )} + {error && "Unable to load users."} + + ); }); diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index 57944cff..f43af305 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -5,7 +5,7 @@ import prisma from "src/lib/prismadb"; /** * The number of users to fetch in a single request. Could later be a query parameter. */ -const PAGE_SIZE = 20; +const PAGE_SIZE = 2; /** * Returns a list of user results from the database when the requesting user is From aebfaacac8e5abb8a5b5097ea6f30d68afd9763c Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 21 Jan 2023 23:21:10 +0700 Subject: [PATCH 010/111] remove debug code --- website/src/pages/api/admin/users.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index f43af305..57944cff 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -5,7 +5,7 @@ import prisma from "src/lib/prismadb"; /** * The number of users to fetch in a single request. Could later be a query parameter. */ -const PAGE_SIZE = 2; +const PAGE_SIZE = 20; /** * Returns a list of user results from the database when the requesting user is From 15acd1c64e917910e68b32c800f23ff516655d43 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 21 Jan 2023 23:56:21 +0700 Subject: [PATCH 011/111] handle error --- website/src/components/DataTable.tsx | 94 +++++++++++++++------------- website/src/components/UserTable.tsx | 34 +++++----- 2 files changed, 64 insertions(+), 64 deletions(-) diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx index 1784650a..f9ef4e49 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable.tsx @@ -1,6 +1,8 @@ import { Box, Button, + Card, + CardBody, Flex, FormControl, FormLabel, @@ -77,51 +79,53 @@ export const DataTable = ({ onFilterChange(newValues); }; return ( - <> - - - - - - - - {caption} - - {getHeaderGroups().map((headerGroup) => ( - - {headerGroup.headers.map((header) => ( - - ))} - - ))} - - - {getRowModel().rows.map((row) => ( - - {row.getVisibleCells().map((cell) => ( - - ))} - - ))} - -
- - {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} - {(header.column.columnDef as DataTableColumnDef).filterable && ( - value.id === header.id)?.value ?? ""} - onChange={(value) => handleFilterChange({ id: header.id, value })} - label={flexRender(header.column.columnDef.header, header.getContext())} - > - )} - -
{flexRender(cell.column.columnDef.cell, cell.getContext())}
-
- + + + + + + + + + + {caption} + + {getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + ))} + + ))} + + + {getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + ))} + + ))} + +
+ + {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} + {(header.column.columnDef as DataTableColumnDef).filterable && ( + value.id === header.id)?.value ?? ""} + onChange={(value) => handleFilterChange({ id: header.id, value })} + label={flexRender(header.column.columnDef.header, header.getContext())} + > + )} + +
{flexRender(cell.column.columnDef.cell, cell.getContext())}
+
+
+
); }; diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index 1f4ccfca..df412bbc 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -1,4 +1,4 @@ -import { Card, CardBody, IconButton } from "@chakra-ui/react"; +import { IconButton } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; import Link from "next/link"; import { memo, useState } from "react"; @@ -90,23 +90,19 @@ export const UserTable = memo(function UserTable() { }; return ( - - - {data && ( - - )} - {error && "Unable to load users."} - - + <> + + {error && "Unable to load users."} + ); }); From f5b2a348577d93ea45f2f5ff186ef8f20543083e Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sun, 22 Jan 2023 00:56:17 +0000 Subject: [PATCH 012/111] [feature] add pythia and limit translation pair --- model/supervised_finetuning/custom_datasets/translation.py | 2 ++ model/supervised_finetuning/utils.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index 008de751..f9a71a8e 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -100,6 +100,8 @@ class WMT2019(TranslationPair): else: # translating in reverse direction source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) + if len(self.pairs) > 100000: + break class DiveMT(TranslationPair): diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 7b6e03b6..f7a0ab15 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -25,6 +25,10 @@ def get_tokenizer(conf): tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"}) elif "codegen" in conf.model_name: tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"}) + elif "pythia" in conf.model_name: + tokenizer.add_special_tokens( + {"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"} + ) additional_special_tokens = ( [] From ab09a3f50fe17242c32c7e82f2b94e409561e73e Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sun, 22 Jan 2023 16:13:35 +0900 Subject: [PATCH 013/111] Adding a valid language check --- website/src/lib/users.ts | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/website/src/lib/users.ts b/website/src/lib/users.ts index 3dbe5a08..3f1cb65d 100644 --- a/website/src/lib/users.ts +++ b/website/src/lib/users.ts @@ -3,13 +3,26 @@ import type { NextApiRequest } from "next"; import prisma from "src/lib/prismadb"; import type { BackendUserCore } from "src/types/Users"; +import { i18n } from "src/../next-i18next.config"; + +const LOCALE_SET = new Set(i18n.locales); + +/** + * Returns the most appropriate user language using the following priority: + * + * 1. The `NEXT_LOCALE` cookie which is set by the client side and will be in + * the set of supported locales. + * 2. The `accept-language` header if it contains a supported locale as set by + * the i18n module. + * 3. "en" as a final fallback. + */ const getUserLanguage = (req: NextApiRequest) => { const cookieLanguage = req.cookies["NEXT_LOCALE"]; if (cookieLanguage) { return cookieLanguage; } const headerLanguages = parser.parse(req.headers["accept-language"]); - if (headerLanguages.length > 0) { + if (headerLanguages.length > 0 && LOCALE_SET.has(headerLanguages[0].code)) { return headerLanguages[0].code; } return "en"; From 6945cc5fe7c7028901c0834ea8d24a9ef628c268 Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 22 Jan 2023 14:14:41 +0700 Subject: [PATCH 014/111] remove `reverse` method --- backend/oasst_backend/user_repository.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 670bbc3e..118f4e82 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -269,9 +269,4 @@ class UserRepository: if limit is not None: qry = qry.limit(limit) - users = qry.all() - - if lte_display_name and lt_id: - users.reverse() - - return users + return qry.all() From 79331df366f3a95bb5209b210552189457ab9723 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sun, 22 Jan 2023 16:15:54 +0900 Subject: [PATCH 015/111] Fixing an import order --- website/src/components/LanguageSelector/LanguageSelector.tsx | 2 +- website/src/lib/users.ts | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/website/src/components/LanguageSelector/LanguageSelector.tsx b/website/src/components/LanguageSelector/LanguageSelector.tsx index e611bf0f..37659265 100644 --- a/website/src/components/LanguageSelector/LanguageSelector.tsx +++ b/website/src/components/LanguageSelector/LanguageSelector.tsx @@ -1,7 +1,7 @@ import { Select } from "@chakra-ui/react"; import { useRouter } from "next/router"; import { useTranslation } from "next-i18next"; -import { useCallback, useMemo, useState } from "react"; +import { useCallback, useMemo } from "react"; import cookie from "react-cookies"; const LanguageSelector = () => { diff --git a/website/src/lib/users.ts b/website/src/lib/users.ts index 3f1cb65d..637e93a9 100644 --- a/website/src/lib/users.ts +++ b/website/src/lib/users.ts @@ -1,10 +1,9 @@ import parser from "accept-language-parser"; import type { NextApiRequest } from "next"; +import { i18n } from "src/../next-i18next.config"; import prisma from "src/lib/prismadb"; import type { BackendUserCore } from "src/types/Users"; -import { i18n } from "src/../next-i18next.config"; - const LOCALE_SET = new Set(i18n.locales); /** From 101f2c536a3ebbfe744637fc9bace22768755eb3 Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 22 Jan 2023 14:20:39 +0700 Subject: [PATCH 016/111] revert change in user_repository --- backend/oasst_backend/user_repository.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 118f4e82..79df99ab 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -233,15 +233,11 @@ class UserRepository: if lte_display_name is not None: if lt_id: - qry = ( - qry.filter( - or_( - User.display_name < lte_display_name, - and_(User.display_name == lte_display_name, User.id < lt_id), - ) + qry = qry.filter( + or_( + User.display_name < lte_display_name, + and_(User.display_name == lte_display_name, User.id < lt_id), ) - .order_by(None) - .order_by(User.display_name.desc(), User.id.desc()) ) else: qry = qry.filter(User.display_name <= lte_display_name) From 3b5b6669a503114019f08ffe09c605851c90b304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 22 Jan 2023 08:42:38 +0100 Subject: [PATCH 017/111] move lt-desc order to users-cursor function --- backend/oasst_backend/api/v1/users.py | 14 ++++++++++---- backend/oasst_backend/user_repository.py | 24 ++++++++++++------------ 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index ed07eb21..c7ff9f9c 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -28,6 +28,7 @@ def get_users_ordered_by_username( search_text: Optional[str] = None, auth_method: Optional[str] = None, max_count: Optional[int] = Query(100, gt=0, le=10000), + desc: Optional[bool] = False, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): @@ -41,6 +42,7 @@ def get_users_ordered_by_username( auth_method=auth_method, search_text=search_text, limit=max_count, + desc=desc, ) return [u.to_protocol_frontend_user() for u in users] @@ -55,6 +57,7 @@ def get_users_ordered_by_display_name( auth_method: Optional[str] = None, search_text: Optional[str] = None, max_count: Optional[int] = Query(100, gt=0, le=10000), + desc: Optional[bool] = False, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): @@ -68,6 +71,7 @@ def get_users_ordered_by_display_name( auth_method=auth_method, search_text=search_text, limit=max_count, + desc=desc, ) return [u.to_protocol_frontend_user() for u in users] @@ -96,6 +100,7 @@ def get_users_cursor( items: list[protocol.FrontEndUser] qry_max_count = max_count + 1 if lt is None or gt is None else max_count + desc = lt and not gt def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]): p, n = None, None @@ -115,10 +120,9 @@ def get_users_cursor( num_rows = len(items) if qry_max_count > max_count and num_rows == qry_max_count: assert not (lt and gt) - if lt: - items = items[1:] - else: - items = items[:-1] + items = items[:-1] + if desc: + items.reverse() return items, num_rows n, p = None, None @@ -134,6 +138,7 @@ def get_users_cursor( auth_method=auth_method, search_text=search_text, max_count=qry_max_count, + desc=desc, api_client=api_client, db=db, ) @@ -152,6 +157,7 @@ def get_users_cursor( auth_method=auth_method, search_text=search_text, max_count=qry_max_count, + desc=desc, api_client=api_client, db=db, ) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 44a7e685..b467bcaf 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -145,6 +145,7 @@ class UserRepository: auth_method: Optional[str] = None, search_text: Optional[str] = None, limit: Optional[int] = 100, + desc: bool = False, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -184,14 +185,13 @@ class UserRepository: pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) qry = qry.filter(User.username.like(pattern)) - if limit is not None and lte_username and not gte_username: - # select top rows but return results in ascernding order - sub_qry = qry.order_by(User.username.desc(), User.id.desc()).limit(limit).subquery("u") - qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.username, User.id) + if desc: + qry = qry.order_by(User.username.desc(), User.id.desc()) else: qry = qry.order_by(User.username, User.id) - if limit is not None: - qry = qry.limit(limit) + + if limit is not None: + qry = qry.limit(limit) return qry.all() @@ -205,6 +205,7 @@ class UserRepository: auth_method: Optional[str] = None, search_text: Optional[str] = None, limit: Optional[int] = 100, + desc: bool = False, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -255,13 +256,12 @@ class UserRepository: if auth_method: qry = qry.filter(User.auth_method == auth_method) - if limit is not None and lte_display_name and not gte_display_name: - # select top rows but return results in ascernding order - sub_qry = qry.order_by(User.display_name.desc(), User.id.desc()).limit(limit).subquery("u") - qry = self.db.query(User).select_entity_from(sub_qry).order_by(User.display_name, User.id) + if desc: + qry = qry.order_by(User.display_name.desc(), User.id.desc()) else: qry = qry.order_by(User.display_name, User.id) - if limit is not None: - qry = qry.limit(limit) + + if limit is not None: + qry = qry.limit(limit) return qry.all() From 28089d9ecf6766086298c4a1034f06046062b9cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 22 Jan 2023 09:29:21 +0100 Subject: [PATCH 018/111] fix username+auth combo check --- backend/oasst_backend/prompt_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index abf5b721..8c259dda 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -693,7 +693,7 @@ class PromptRepository: if user_id: qry = qry.filter(Message.user_id == user_id) if username or auth_method: - if not username and auth_method: + if not (username and auth_method): raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED) qry = qry.join(User) qry = qry.filter(User.username == username, User.auth_method == auth_method) From 952c61f4613a74e387c1819b56ea7a0fd12e9c44 Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 22 Jan 2023 15:50:23 +0700 Subject: [PATCH 019/111] Fix recent messages --- website/src/pages/api/messages/user.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index e5f361b8..bd651acc 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -4,6 +4,7 @@ const handler = withoutRole("banned", async (req, res, token) => { //TODO: add params if needed const params = new URLSearchParams({ username: token.sub, + auth_method: "local", }); const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { From 554e730d348f766e49af2e71760c303757919815 Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 22 Jan 2023 16:01:01 +0700 Subject: [PATCH 020/111] user `getBackendUserCore` --- website/src/pages/api/messages/user.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index bd651acc..cbf658f1 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -1,10 +1,12 @@ import { withoutRole } from "src/lib/auth"; +import { getBackendUserCore } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { //TODO: add params if needed + const user = await getBackendUserCore(token.sub); const params = new URLSearchParams({ username: token.sub, - auth_method: "local", + auth_method: user.auth_method, }); const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, { From 0f0d0e00b5494775355035aa9c5cad7cd170d7d9 Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 22 Jan 2023 17:05:22 +0700 Subject: [PATCH 021/111] use `user.id` --- website/src/pages/api/messages/user.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/pages/api/messages/user.ts b/website/src/pages/api/messages/user.ts index cbf658f1..6f39aad1 100644 --- a/website/src/pages/api/messages/user.ts +++ b/website/src/pages/api/messages/user.ts @@ -5,7 +5,7 @@ const handler = withoutRole("banned", async (req, res, token) => { //TODO: add params if needed const user = await getBackendUserCore(token.sub); const params = new URLSearchParams({ - username: token.sub, + username: user.id, auth_method: user.auth_method, }); From c0391a6df9aff9bf83e2baa1b59d0d5478ad434b Mon Sep 17 00:00:00 2001 From: James Melvin Ebenezer Date: Sun, 22 Jan 2023 15:38:02 +0530 Subject: [PATCH 022/111] fix: redundant row updates with no Task id in text_labels table (#876) * fix: redundant row updates with no Task id in text_labels table * fix: review comments incorporated * fix: better error handling and function name * fix: review comments Co-authored-by: James Melvin --- backend/oasst_backend/api/v1/text_labels.py | 16 +++++++++++----- backend/oasst_backend/prompt_repository.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index c9afd88c..dc6cc889 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -4,8 +4,9 @@ from loguru import logger from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse +from oasst_backend.utils.database_utils import CommitMode, managed_tx_function +from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -14,20 +15,25 @@ router = APIRouter() @router.post("/", status_code=HTTP_204_NO_CONTENT) def label_text( *, - db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), text_labels: protocol_schema.TextLabels, ) -> None: """ Label a piece of text. """ - api_client = deps.api_auth(api_key, db) + + @managed_tx_function(CommitMode.COMMIT) + def store_text_labels(session: deps.Session): + api_client = deps.api_auth(api_key, session) + pr = PromptRepository(session, api_client, client_user=text_labels.user) + pr.store_text_labels(text_labels) try: logger.info(f"Labeling text {text_labels=}.") - pr = PromptRepository(db, api_client, client_user=text_labels.user) - pr.store_text_labels(text_labels) + store_text_labels() + except OasstError: + raise except Exception: logger.exception("Failed to store label.") raise HTTPException( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 8c259dda..0a0fa61d 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -448,6 +448,11 @@ class PromptRepository: if task: message.review_count += 1 self.db.add(message) + # for the same User id with no task id associated with the message, then update existing record for repeated updates + existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) + if existing_text_label is not None: + existing_text_label.labels = text_labels.labels + model = existing_text_label self.db.add(model) return model, task, message @@ -561,6 +566,16 @@ class PromptRepository: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) return message + def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]: + + query = ( + self.db.query(TextLabels) + .outerjoin(Task, Task.id == TextLabels.id) + .filter(Task.id.is_(None), TextLabels.message_id == message_id, TextLabels.user_id == user_id) + ) + text_label = query.one_or_none() + return text_label + @staticmethod def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]: """ From 6167f63467e0bc98a644b52f9dd57090e1b55287 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 22 Jan 2023 11:53:34 +0100 Subject: [PATCH 023/111] Present available task types on dashboard --- .../src/components/Dashboard/TaskOption.tsx | 77 +++++++++++-------- website/src/components/Tasks/Task/Task.tsx | 4 +- website/src/components/Tasks/TaskTypes.tsx | 8 +- website/src/pages/dashboard.tsx | 10 ++- 4 files changed, 57 insertions(+), 42 deletions(-) diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx index e2bafac3..5a759d40 100644 --- a/website/src/components/Dashboard/TaskOption.tsx +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -1,48 +1,61 @@ import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Link from "next/link"; +import { useMemo } from "react"; +import { TaskType } from "src/types/Task"; -import { TaskCategory, TaskCategoryLabels, TaskTypes } from "../Tasks/TaskTypes"; +import { TaskCategory, TaskCategoryLabels, TaskInfo, TaskInfos } from "../Tasks/TaskTypes"; -export const TaskOption = ({ displayTaskCategories }: { displayTaskCategories: TaskCategory[] }) => { +export interface TasksOptionProps { + content: Partial>; +} + +export const TaskOption = ({ content }: TasksOptionProps) => { const backgroundColor = useColorModeValue("white", "gray.700"); + const taskInfoMap = useMemo( + () => + Object.values(content) + .flat() + .reduce((obj, taskType) => { + obj[taskType] = TaskInfos.filter((t) => t.type === taskType).pop(); + return obj; + }, {} as Record), + [content] + ); + return ( - {displayTaskCategories.map((category) => ( + {Object.entries(content).map(([category, taskTypes]) => (
- {TaskCategoryLabels[category]} + + {TaskCategoryLabels[category]} + - {TaskTypes.filter((task) => task.category === category).map((item) => ( - - - - - - {item.label} - - - {item.desc} - - - - taskInfoMap[taskType]) + .map((item) => ( + + - + + {item.label} + {item.desc} + + Go -> - - - - ))} + + + ))}
))} diff --git a/website/src/components/Tasks/Task/Task.tsx b/website/src/components/Tasks/Task/Task.tsx index 3d393575..45fb83d0 100644 --- a/website/src/components/Tasks/Task/Task.tsx +++ b/website/src/components/Tasks/Task/Task.tsx @@ -3,7 +3,7 @@ import { TaskControls } from "src/components/Survey/TaskControls"; import { CreateTask } from "src/components/Tasks/CreateTask"; import { EvaluateTask } from "src/components/Tasks/EvaluateTask"; import { LabelTask } from "src/components/Tasks/LabelTask"; -import { TaskCategory, TaskInfo, TaskTypes } from "src/components/Tasks/TaskTypes"; +import { TaskCategory, TaskInfo, TaskInfos } from "src/components/Tasks/TaskTypes"; import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning"; import { post } from "src/lib/api"; import { TaskContent } from "src/types/Task"; @@ -29,7 +29,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { const rootEl = useRef(null); - const taskType = TaskTypes.find((taskType) => taskType.type === task.type && taskType.mode === task.mode); + const taskType = TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode); const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, { onSuccess: async () => { diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 4c6da92c..d10159d9 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -21,16 +21,16 @@ export interface TaskInfo { } export const TaskCategoryLabels: { [key in TaskCategory]: string } = { - [TaskCategory.Random]: "I'm feeling lucky", + [TaskCategory.Random]: "Grab a task!", [TaskCategory.Create]: "Create", [TaskCategory.Evaluate]: "Evaluate", [TaskCategory.Label]: "Label", }; -export const TaskTypes: TaskInfo[] = [ +export const TaskInfos: TaskInfo[] = [ // general/random { - label: "Start a Task", + label: "I'm feeling lucky", desc: "Help us improve Open Assistant by starting a random task.", category: TaskCategory.Random, pathname: "/tasks/random", @@ -104,7 +104,7 @@ export const TaskTypes: TaskInfo[] = [ category: TaskCategory.Evaluate, pathname: "/evaluate/rank_initial_prompts", help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting", - overview: "Given the following inital prompts, sort them from best to worst, best being first, worst being last.", + overview: "Given the following initial prompts, sort them from best to worst, best being first, worst being last.", type: "rank_initial_prompts", update_type: "message_ranking", unchanged_title: "Order Unchanged", diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index e0b8bba4..4def6196 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -5,15 +5,17 @@ import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashbo import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; import { get } from "src/lib/api"; -import type { AvailableTasks, TaskType } from "src/types/Task"; +import { AvailableTasks, TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; import useSWRImmutable from "swr/immutable"; const Dashboard = () => { const { data } = useSWRImmutable("/api/available_tasks", get); - // TODO: show only these tasks: - const availableTasks = useMemo(() => filterAvailableTasks(data ?? {}), [data]); + const availableTaskTypes = useMemo(() => { + const taskTypes = filterAvailableTasks(data ?? {}); + return { [TaskCategory.Random]: taskTypes }; + }, [data]); return ( <> @@ -23,7 +25,7 @@ const Dashboard = () => { - + From fd703663bd1bea984a09869da6124d4dd49641a2 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 22 Jan 2023 12:00:55 +0100 Subject: [PATCH 024/111] Fix all tasks page --- website/src/components/Dashboard/TaskOption.tsx | 11 +++++++++++ website/src/pages/tasks/all.tsx | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx index 5a759d40..e73a06c8 100644 --- a/website/src/components/Dashboard/TaskOption.tsx +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -62,3 +62,14 @@ export const TaskOption = ({ content }: TasksOptionProps) => {
); }; + +export const allTaskOptions: TasksOptionProps["content"] = { + [TaskCategory.Random]: [TaskType.random], + [TaskCategory.Create]: [TaskType.initial_prompt, TaskType.prompter_reply, TaskType.assistant_reply], + [TaskCategory.Evaluate]: [ + TaskType.rank_initial_prompts, + TaskType.rank_prompter_replies, + TaskType.rank_assistant_replies, + ], + [TaskCategory.Label]: [TaskType.label_initial_prompt, TaskType.label_prompter_reply, TaskType.label_assistant_reply], +}; diff --git a/website/src/pages/tasks/all.tsx b/website/src/pages/tasks/all.tsx index 3ccfd4e8..01954c2f 100644 --- a/website/src/pages/tasks/all.tsx +++ b/website/src/pages/tasks/all.tsx @@ -1,7 +1,7 @@ import Head from "next/head"; import { TaskOption } from "src/components/Dashboard"; +import { allTaskOptions } from "src/components/Dashboard/TaskOption"; import { getDashboardLayout } from "src/components/Layout"; -import { TaskCategory } from "src/components/Tasks/TaskTypes"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const AllTasks = () => { @@ -11,7 +11,7 @@ const AllTasks = () => { All Tasks - Open Assistant - + ); }; From fb4e94487cd407c941096400e10c4704c88a6913 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sun, 22 Jan 2023 12:15:21 +0100 Subject: [PATCH 025/111] Redirect users to dashboard if there are no tasks --- website/src/components/EmptyState.tsx | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx index 14715518..51e51a00 100644 --- a/website/src/components/EmptyState.tsx +++ b/website/src/components/EmptyState.tsx @@ -1,5 +1,5 @@ -import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react"; -import { useRouter } from "next/router"; +import { Box, Text, useColorModeValue } from "@chakra-ui/react"; +import NextLink from "next/link"; import { FiAlertTriangle } from "react-icons/fi"; import { IconType } from "react-icons/lib"; @@ -10,16 +10,15 @@ type EmptyStateProps = { export const EmptyState = (props: EmptyStateProps) => { const backgroundColor = useColorModeValue("white", "gray.800"); - const router = useRouter(); return ( {props.text} - router.back()} color="blue.500" textUnderlineOffset="3px"> - Click here to go back - + + Go back to the dashboard + ); From d466e63d66ef9081b981c7a0598f7ec24cff867c Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 22 Jan 2023 20:35:31 +0700 Subject: [PATCH 026/111] wip --- website/package-lock.json | 15 +++++++++++++++ website/package.json | 1 + website/src/components/CallToAction.tsx | 5 +++-- website/src/components/EmptyState.tsx | 7 +++---- website/src/components/Header/Header.tsx | 4 ++-- website/src/components/Header/UserMenu.tsx | 14 +++++++------- website/src/components/Layout.tsx | 12 ++++++------ website/src/components/SideMenu.tsx | 9 ++++----- website/src/components/Sortable/SortableItem.tsx | 4 ++-- website/src/components/Survey/TaskControls.tsx | 10 ++++++++-- .../components/Tasks/TaskHeader/TaskHeader.tsx | 4 ++-- website/src/pages/404.tsx | 6 +++--- website/src/pages/500.tsx | 9 +++------ website/src/pages/account/index.tsx | 6 +++--- 14 files changed, 62 insertions(+), 44 deletions(-) diff --git a/website/package-lock.json b/website/package-lock.json index 5c5dc795..65fde28e 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -32,6 +32,7 @@ "focus-visible": "^5.2.0", "framer-motion": "^6.5.1", "install": "^0.13.0", + "lucide-react": "^0.105.0", "next": "13.0.6", "next-auth": "^4.18.6", "next-i18next": "^13.0.3", @@ -26694,6 +26695,14 @@ "yallist": "^3.0.2" } }, + "node_modules/lucide-react": { + "version": "0.105.0", + "resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.105.0.tgz", + "integrity": "sha512-iHaIkd4Wq6aNIVrFMXt3If8E/+2lnJd4WlCyntoJNIzZ8nWhdSSHWpsw7XM4rlw2319LZ2t4WLdnM8Z0ECDTOQ==", + "peerDependencies": { + "react": "^16.5.1 || ^17.0.0 || ^18.0.0" + } + }, "node_modules/lz-string": { "version": "1.4.4", "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.4.4.tgz", @@ -57869,6 +57878,12 @@ "yallist": "^3.0.2" } }, + "lucide-react": { + "version": "0.105.0", + "resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.105.0.tgz", + "integrity": "sha512-iHaIkd4Wq6aNIVrFMXt3If8E/+2lnJd4WlCyntoJNIzZ8nWhdSSHWpsw7XM4rlw2319LZ2t4WLdnM8Z0ECDTOQ==", + "requires": {} + }, "lz-string": { "version": "1.4.4", "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.4.4.tgz", diff --git a/website/package.json b/website/package.json index 8866a9e2..c0195163 100644 --- a/website/package.json +++ b/website/package.json @@ -49,6 +49,7 @@ "focus-visible": "^5.2.0", "framer-motion": "^6.5.1", "install": "^0.13.0", + "lucide-react": "^0.105.0", "next": "13.0.6", "next-auth": "^4.18.6", "next-i18next": "^13.0.3", diff --git a/website/src/components/CallToAction.tsx b/website/src/components/CallToAction.tsx index e374a471..ddeed6b7 100644 --- a/website/src/components/CallToAction.tsx +++ b/website/src/components/CallToAction.tsx @@ -1,7 +1,8 @@ import { Box, Link, Text, useColorMode } from "@chakra-ui/react"; +import { Github } from "lucide-react"; import { useTranslation } from "next-i18next"; import { useId } from "react"; -import { FaDiscord, FaGithub } from "react-icons/fa"; +import { FaDiscord } from "react-icons/fa"; import { Container } from "./Container"; @@ -81,7 +82,7 @@ export function CallToAction() { type="button" className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2" > - + {t("github")} diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx index 14715518..d2acea7f 100644 --- a/website/src/components/EmptyState.tsx +++ b/website/src/components/EmptyState.tsx @@ -1,11 +1,10 @@ import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react"; +import { AlertTriangle, LucideIcon } from "lucide-react"; import { useRouter } from "next/router"; -import { FiAlertTriangle } from "react-icons/fi"; -import { IconType } from "react-icons/lib"; type EmptyStateProps = { text: string; - icon: IconType; + icon: LucideIcon; }; export const EmptyState = (props: EmptyStateProps) => { @@ -26,5 +25,5 @@ export const EmptyState = (props: EmptyStateProps) => { }; export const TaskEmptyState = () => { - return ; + return ; }; diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx index 64614578..0d70a442 100644 --- a/website/src/components/Header/Header.tsx +++ b/website/src/components/Header/Header.tsx @@ -1,10 +1,10 @@ import { Box, Button, Flex, Text } from "@chakra-ui/react"; +import { User } from "lucide-react"; import Image from "next/image"; import Link from "next/link"; import { useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import { Flags } from "react-feature-flags"; -import { FaUser } from "react-icons/fa"; import { LanguageSelector } from "src/components/LanguageSelector"; import { UserMenu } from "./UserMenu"; @@ -17,7 +17,7 @@ function AccountButton() { return ( - diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 6fdde69e..bc2a377e 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -11,11 +11,11 @@ import { Text, useColorModeValue, } from "@chakra-ui/react"; +import { AlertTriangle, Layout, LogOut, Settings, Shield } from "lucide-react"; import NextLink from "next/link"; import { signOut, useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import React, { ElementType, useCallback } from "react"; -import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi"; interface MenuOption { name: string; @@ -41,21 +41,21 @@ export function UserMenu() { name: t("dashboard"), href: "/dashboard", desc: t("dashboard"), - icon: FiLayout, + icon: Layout, isExternal: false, }, { name: t("account_settings"), href: "/account", desc: t("account_settings"), - icon: FiSettings, + icon: Settings, isExternal: false, }, { name: t("report_a_bug"), href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose", desc: t("report_a_bug"), - icon: FiAlertTriangle, + icon: AlertTriangle, isExternal: true, }, ]; @@ -65,7 +65,7 @@ export function UserMenu() { name: t("admin_dashboard"), href: "/admin", desc: t("admin_dashboard"), - icon: FiShield, + icon: Shield, isExternal: false, }); } @@ -98,7 +98,7 @@ export function UserMenu() { _hover={{ textDecoration: "none" }} > - @@ -106,7 +106,7 @@ export function UserMenu() { - diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index 55085550..1b5bf430 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -1,8 +1,8 @@ // https://nextjs.org/docs/basic-features/layouts import { Box, Grid } from "@chakra-ui/react"; +import { Activity, BarChart2, Layout, MessageSquare, Users } from "lucide-react"; import type { NextPage } from "next"; -import { FiBarChart2, FiLayout, FiMessageSquare, FiUsers, FiActivity } from "react-icons/fi"; import { Header } from "src/components/Header"; import { SlimFooter } from "./Dashboard/SlimFooter"; @@ -38,19 +38,19 @@ export const getDashboardLayout = (page: React.ReactElement) => ( label: "Dashboard", pathname: "/dashboard", desc: "Dashboard Home", - icon: FiLayout, + icon: Layout, }, { label: "Messages", pathname: "/messages", desc: "Messages Dashboard", - icon: FiMessageSquare, + icon: MessageSquare, }, { label: "Leaderboard", pathname: "/leaderboard", desc: "User Leaderboard", - icon: FiBarChart2, + icon: BarChart2, }, ]} > @@ -73,13 +73,13 @@ export const getAdminLayout = (page: React.ReactElement) => ( label: "Users", pathname: "/admin", desc: "Users Dashboard", - icon: FiUsers, + icon: Users, }, { label: "Status", pathname: "/admin/status", desc: "Status Dashboard", - icon: FiActivity, + icon: Activity, }, ]} > diff --git a/website/src/components/SideMenu.tsx b/website/src/components/SideMenu.tsx index 3722eaa8..10e83ce4 100644 --- a/website/src/components/SideMenu.tsx +++ b/website/src/components/SideMenu.tsx @@ -1,15 +1,14 @@ import { Box, Button, Text, Tooltip, useColorMode } from "@chakra-ui/react"; +import { LucideIcon, Sun } from "lucide-react"; import Link from "next/link"; import { useRouter } from "next/router"; -import { FiSun } from "react-icons/fi"; -import { IconType } from "react-icons/lib"; import { colors } from "styles/Theme/colors"; export interface MenuButtonOption { label: string; pathname: string; desc: string; - icon: IconType; + icon: LucideIcon; } export interface SideMenuProps { @@ -47,7 +46,7 @@ export function SideMenu(props: SideMenuProps) { bg={router.pathname === item.pathname ? "blue.500" : null} _hover={router.pathname === item.pathname ? { bg: "blue.600" } : null} > - + diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 7e28f2c2..58c7559f 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -22,8 +22,8 @@ import { } from "@chakra-ui/react"; import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid"; import clsx from "clsx"; +import { AlertCircle } from "lucide-react"; import { useEffect, useReducer } from "react"; -import { FiAlertCircle } from "react-icons/fi"; import { get, post } from "src/lib/api"; import { colors } from "src/styles/Theme/colors"; import { Message } from "src/types/Conversation"; @@ -154,7 +154,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => { - diff --git a/website/src/components/Icons/Discord.tsx b/website/src/components/Icons/Discord.tsx new file mode 100644 index 00000000..ea1118fb --- /dev/null +++ b/website/src/components/Icons/Discord.tsx @@ -0,0 +1,16 @@ +import { LucideIcon } from "lucide-react"; + +export const Discord: LucideIcon = ({ size = 24, ...rest }) => { + return ( + + + + ); +}; diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index df412bbc..5e5828ea 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -1,8 +1,8 @@ import { IconButton } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; +import { Pencil } from "lucide-react"; import Link from "next/link"; import { memo, useState } from "react"; -import { FaPen } from "react-icons/fa"; import { get } from "src/lib/api"; import { FetchUsersResponse } from "src/lib/oasst_api_client"; import type { User } from "src/types/Users"; @@ -49,7 +49,7 @@ const columns: DataTableColumnDef[] = [ as={Link} href={`/admin/manage_user/${getValue()}`} aria-label="Manage" - icon={} + icon={} > ), header: "Update", diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx index e3757190..d171182d 100644 --- a/website/src/pages/auth/signin.tsx +++ b/website/src/pages/auth/signin.tsx @@ -1,17 +1,18 @@ import { Button, ButtonProps, Input, Stack, useColorModeValue } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; +import { Bug, Github, Mail } from "lucide-react"; import { GetServerSideProps } from "next"; import Head from "next/head"; import Link from "next/link"; import { useRouter } from "next/router"; import { ClientSafeProvider, getProviders, signIn } from "next-auth/react"; import { serverSideTranslations } from "next-i18next/serverSideTranslations"; -import React, { useEffect, useRef, useState } from "react"; +import React, { useEffect, useState } from "react"; import { useForm } from "react-hook-form"; -import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; import { AuthLayout } from "src/components/AuthLayout"; import { Footer } from "src/components/Footer"; import { Header } from "src/components/Header"; +import { Discord } from "src/components/Icons/Discord"; import { Role, RoleSelect } from "src/components/RoleSelect"; export type SignInErrorTypes = @@ -89,7 +90,7 @@ function Signin({ providers }: SigninProps) { placeholder="Email Address" {...register("email")} /> - }> + }> Continue with Email @@ -103,7 +104,7 @@ function Signin({ providers }: SigninProps) { bg: "#454FBF", }} size="lg" - leftIcon={} + leftIcon={} color="white" onClick={() => signIn(discord.id, { callbackUrl: "/" })} // isDisabled="false" @@ -119,7 +120,7 @@ function Signin({ providers }: SigninProps) { bg: "#101010", }} size={"lg"} - leftIcon={} + leftIcon={} colorScheme="blue" // isDisabled="false" > @@ -165,7 +166,7 @@ const SigninButton = (props: ButtonProps) => { return ( + ); + })} + + ); +}; diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx index a9f29bc2..b0455774 100644 --- a/website/src/components/EmptyState.tsx +++ b/website/src/components/EmptyState.tsx @@ -5,13 +5,14 @@ import NextLink from "next/link"; type EmptyStateProps = { text: string; icon: LucideIcon; + "data-cy"?: string; }; export const EmptyState = (props: EmptyStateProps) => { const backgroundColor = useColorModeValue("white", "gray.800"); return ( - + {props.text} @@ -24,5 +25,5 @@ export const EmptyState = (props: EmptyStateProps) => { }; export const TaskEmptyState = () => { - return ; + return ; }; diff --git a/website/src/components/Explain.tsx b/website/src/components/Explain.tsx new file mode 100644 index 00000000..b571757f --- /dev/null +++ b/website/src/components/Explain.tsx @@ -0,0 +1,39 @@ +import { + IconButton, + Popover, + PopoverArrow, + PopoverBody, + PopoverCloseButton, + PopoverContent, + PopoverTrigger, + Text, +} from "@chakra-ui/react"; +import { InformationCircleIcon } from "@heroicons/react/20/solid"; + +interface ExplainProps { + explanation: string[]; +} + +export const Explain = ({ explanation }: ExplainProps) => { + return ( + + + } + > + + + + + + {explanation.map((paragraph, idx) => ( + {paragraph} + ))} + + + + ); +}; diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 58c7559f..0bebce97 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -1,127 +1,69 @@ import { Box, Button, - Checkbox, - Flex, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, Popover, PopoverAnchor, - PopoverArrow, - PopoverBody, - PopoverCloseButton, - PopoverContent, PopoverTrigger, - Slider, - SliderFilledTrack, - SliderThumb, - SliderTrack, Tooltip, - useBoolean, - useColorMode, useColorModeValue, - useId, + useDisclosure, } from "@chakra-ui/react"; -import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid"; -import clsx from "clsx"; import { AlertCircle } from "lucide-react"; -import { useEffect, useReducer } from "react"; +import { useState } from "react"; import { get, post } from "src/lib/api"; import { colors } from "src/styles/Theme/colors"; import { Message } from "src/types/Conversation"; -import useSWR from "swr"; +import useSWRImmutable from "swr/immutable"; import useSWRMutation from "swr/mutation"; +import { LabelInputGroup } from "./Survey/LabelInputGroup"; + interface Label { name: string; display_text: string; help_text: string; } -interface LoadLabelsAction { - type: "load_labels"; - labels: Label[]; -} - -interface UpdateValueAction { - type: "update_value"; - label_index: number; - value: number; -} - -interface ToggleLabelAction { - type: "toggle_label"; - label_index: number; - check: boolean; -} - -interface LabelValue { - label: Label; - checked: boolean; - value: number; -} - -interface FlagReportState { - label_values: LabelValue[]; - submittable: boolean; -} - interface FlaggableElementProps { children: React.ReactNode; message: Message; } +interface ValidLabelsResponse { + valid_labels: Label[]; +} + export const FlaggableElement = (props: FlaggableElementProps) => { - const [report, updateReport] = useReducer( - (state: FlagReportState, action: LoadLabelsAction | UpdateValueAction | ToggleLabelAction): FlagReportState => { - const makeState = (label_values: LabelValue[]): FlagReportState => { - const submittable = label_values.map(({ checked }) => checked).some(Boolean); - return { label_values, submittable }; - }; + const { data: response } = useSWRImmutable("/api/valid_labels", get); + const { isOpen, onOpen, onClose } = useDisclosure(); + const { valid_labels } = response || { valid_labels: [] }; + const [values, setValues] = useState([]); - switch (action.type) { - case "load_labels": - return makeState( - action.labels.map((label) => { - return { label, checked: false, value: 1 }; - }) - ); - case "toggle_label": { - const values_copy = state.label_values.slice(); - values_copy[action.label_index].checked = action.check; - return makeState(values_copy); - } - case "update_value": { - const values_copy = state.label_values.slice(); - values_copy[action.label_index].value = action.value; - return makeState(values_copy); - } - } - }, - { label_values: [], submittable: false } - ); - const [isEditing, setIsEditing] = useBoolean(); - - const { data, isLoading } = useSWR("/api/valid_labels", get); - useEffect(() => { - if (isLoading) { - return; - } - if (!data) { - updateReport({ type: "load_labels", labels: [] }); - return; - } - const { valid_labels } = data; - updateReport({ type: "load_labels", labels: valid_labels }); - }, [data, isLoading]); + const submittable = + values.some((value) => { + return value !== null; + }) && + values.length === valid_labels.length && + valid_labels.length > 0; const { trigger } = useSWRMutation("/api/set_label", post, { - onSuccess: setIsEditing.off, + onSuccess: onClose, + onError: onClose, }); const submitResponse = () => { const label_map: Map = new Map(); - report.label_values.forEach(({ label, checked, value }) => { - if (checked) { - label_map.set(label.name, value); + console.assert(valid_labels.length === values.length); + values.forEach((value, idx) => { + if (value !== null) { + label_map.set(valid_labels[idx].name, value); } }); trigger({ @@ -131,22 +73,8 @@ export const FlaggableElement = (props: FlaggableElementProps) => { }); }; - const handleCheckboxState = (checked, label_index) => { - updateReport({ type: "toggle_label", label_index, check: checked }); - }; - const handleSliderState = (value, label_index) => { - updateReport({ type: "update_value", label_index, value }); - }; - return ( - + {props.children} @@ -161,26 +89,17 @@ export const FlaggableElement = (props: FlaggableElementProps) => { - - - - - - - {report.label_values.map(({ label, checked, value }, i) => ( - - ))} - + + + + Select one or more labels that apply. + + + name)} onChange={setValues} /> + + - - - + + + ); }; - -interface FlagCheckboxProps { - label: Label; - idx: number; - checked: boolean; - sliderValue: number; - checkboxHandler: (newVal: boolean, idx: number) => void; - sliderHandler: (newVal: number, idx: number) => void; -} - -export function FlagCheckbox(props: FlagCheckboxProps): JSX.Element { - let AdditionalExplanation = null; - if (props.label.help_text) { - AdditionalExplanation = ( - - - ); - } - - const id = useId(); - const { colorMode } = useColorMode(); - - const labelTextClass = - colorMode === "light" - ? `text-${colors.light.text} hover:text-blue-700` - : `text-${colors.dark.text} hover:text-blue-400`; - - return ( - -
- { - props.checkboxHandler(e.target.checked, props.idx); - }} - /> - -
-
{ - if (!props.checked) { - props.checkboxHandler(true, props.idx); - } - }} - > - { - props.sliderHandler(val / 100, props.idx); - }} - > - - - - - -
-
- ); -} diff --git a/website/src/components/Survey/LabelInputGroup.tsx b/website/src/components/Survey/LabelInputGroup.tsx new file mode 100644 index 00000000..08d8a628 --- /dev/null +++ b/website/src/components/Survey/LabelInputGroup.tsx @@ -0,0 +1,195 @@ +import { Box, Flex, Grid, Spacer, Text, useColorModeValue, VStack } from "@chakra-ui/react"; +import React from "react"; +import { useState } from "react"; +import { TbChevronLeft, TbChevronRight, TbChevronsLeft, TbChevronsRight } from "react-icons/tb"; +import { LikertButtons } from "src/components/Buttons/LikertButtons"; +import { Explain } from "src/components/Explain"; + +interface LabelInputGroupProps { + labelIDs: Array; + simple?: boolean; + onChange: (values: number[]) => unknown; + isEditable?: boolean; +} + +interface LabelInfo { + zeroText: string; + oneText: string; + zeroDescription: string[]; + oneDescription: string[]; + inverted: boolean; +} + +// This should be moved to the valid labels api endpoint +const label_messages: { + [label: string]: LabelInfo; +} = { + spam: { + zeroText: "Suitable for Training", + zeroDescription: ["Suitable for training Open Assistant."], + oneText: "Spam", + oneDescription: [ + "Seems to be intentionally low-quality or irrelevant", + 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', + "This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", + "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", + ], + inverted: true, + }, + fails_task: { + zeroText: "Follows Instructions", + zeroDescription: [], + oneText: "Fails Task", + oneDescription: ["Fails to follow the correct instruction / task"], + inverted: true, + }, + not_appropriate: { + zeroText: "Appropriate", + zeroDescription: [], + oneText: "Inappropriate", + oneDescription: ["Inappropriate for customer assistant"], + inverted: true, + }, + violence: { + zeroText: "Harmless", + zeroDescription: [], + oneText: "Encourages Violence", + oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"], + inverted: true, + }, + excessive_harm: { + zeroText: "Safe", + zeroDescription: [], + oneText: "Causes Harm", + oneDescription: [ + "Content likely to cause excessive harm not justifiable in the context", + "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", + ], + inverted: true, + }, + sexual_content: { + zeroText: "Non Sexual", + zeroDescription: [], + oneText: "Sexual Content", + oneDescription: ["Contains sexual content"], + inverted: true, + }, + toxicity: { + zeroText: "Non Toxic", + zeroDescription: [], + oneText: "Rude / Toxic", + oneDescription: ["Contains rude, abusive, profane or insulting content"], + inverted: true, + }, + moral_judgement: { + zeroText: "Non-Judgemental", + zeroDescription: [], + oneText: "Judgemental", + oneDescription: ["Expresses moral judgement"], + inverted: true, + }, + political_content: { + zeroText: "Apolitical", + zeroDescription: [], + oneText: "Political", + oneDescription: ["Expresses political views"], + inverted: true, + }, + humor: { + zeroText: "Serious", + zeroDescription: [], + oneText: "Humorous / Sarcastic", + oneDescription: ["Contains humorous content including sarcasm"], + inverted: false, + }, + hate_speech: { + zeroText: "Safe", + zeroDescription: [], + oneText: "Hateful", + oneDescription: [ + "Content is abusive or threatening and expresses prejudice against a protected characteristic", + "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + ], + inverted: true, + }, + threat: { + zeroText: "Safe", + zeroDescription: [], + oneText: "Contains Threat", + oneDescription: ["Contains a threat against a person or persons"], + inverted: true, + }, + misleading: { + zeroText: "Accurate", + zeroDescription: [], + oneText: "Misleading", + oneDescription: ["Contains text which is incorrect or misleading"], + inverted: true, + }, + helpful: { + zeroText: "Unhelful", + zeroDescription: [], + oneText: "Helpful", + oneDescription: ["Completes the task to a high standard"], + inverted: false, + }, + creative: { + zeroText: "Boring", + zeroDescription: [], + oneText: "Creative", + oneDescription: ["Expresses creativity in responding to the task"], + inverted: false, + }, +}; + +export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => { + const [labelValues, setLabelValues] = useState(Array.from({ length: labelIDs.length }).map(() => null)); + + const cardColor = useColorModeValue("gray.50", "gray.800"); + + return ( + + {labelIDs.map((labelId, idx) => { + const { zeroText, oneText, zeroDescription, oneDescription, inverted } = label_messages[labelId]; + + let textA = zeroText; + let textB = oneText; + let descriptionA = zeroDescription; + let descriptionB = oneDescription; + if (inverted) [textA, textB, descriptionA, descriptionB] = [textB, textA, descriptionB, descriptionA]; + + return ( + + + + {textA} + {descriptionA.length > 0 ? : null} + + {textB} + {descriptionB.length > 0 ? : null} + + , + , + "", + , + , + ]} + data-cy="label-options" + value={labelValues[idx] === null ? null : inverted ? 1 - labelValues[idx] : labelValues[idx]} + onChange={(value) => { + const newState = labelValues.slice(); + newState[idx] = value === null ? null : inverted ? 1 - value : value; + onChange(newState); + setLabelValues(newState); + }} + /> + + + ); + })} + + ); +}; diff --git a/website/src/components/Survey/LabelRadioGroup.tsx b/website/src/components/Survey/LabelRadioGroup.tsx deleted file mode 100644 index c4a5a51c..00000000 --- a/website/src/components/Survey/LabelRadioGroup.tsx +++ /dev/null @@ -1,129 +0,0 @@ -import { - Box, - Button, - Flex, - IconButton, - Popover, - PopoverArrow, - PopoverBody, - PopoverCloseButton, - PopoverContent, - PopoverTrigger, - Text, - useColorMode, -} from "@chakra-ui/react"; -import { InformationCircleIcon } from "@heroicons/react/20/solid"; -import { useId, useState } from "react"; -import { colors } from "src/styles/Theme/colors"; - -interface LabelRadioGroupProps { - labelIDs: Array; - onChange: (sliderValues: number[]) => unknown; - isEditable?: boolean; -} - -const label_messages: { [label: string]: { description: string; explanation: string[] } } = { - spam: { - description: "Is the message spam?", - explanation: [ - 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', - "This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", - "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", - ], - }, -}; - -export const LabelRadioGroup = (props: LabelRadioGroupProps) => { - const [labelValues, setLabelValues] = useState(Array.from({ length: props.labelIDs.length }).map(() => 0)); - const [interactionFlag, setInteractionFlag] = useState(false); - - return ( - - {props.labelIDs.map((labelId, idx) => ( - { - const newState = labelValues.slice(); - newState[idx] = newValue; - props.onChange(newState); - setLabelValues(newState); - if (!interactionFlag) setInteractionFlag(true); - }} - states={[ - { text: "No", value: 0 }, - { text: "Yes", value: 1 }, - ]} - isEditable={props.isEditable} - interactionFlag={interactionFlag} - /> - ))} - - ); -}; - -interface ButtonState { - text: string; - value: number; - colorScheme?: string; -} - -interface LabelRadioItemProps { - labelText: { description: string; explanation?: string[] }; - labelValue: number; - clickHandler: (newVal: number) => unknown; - states: ButtonState[]; - isEditable: boolean; - interactionFlag: boolean; -} - -const LabelRadioItem = (props: LabelRadioItemProps) => { - const id = useId(); - const { colorMode } = useColorMode(); - - const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`; - - return ( - - - - {props.states.map((item, idx) => ( - - ))} - - - ); -}; diff --git a/website/src/components/Survey/LabelSliderGroup.tsx b/website/src/components/Survey/LabelSliderGroup.tsx deleted file mode 100644 index 1c3b29b5..00000000 --- a/website/src/components/Survey/LabelSliderGroup.tsx +++ /dev/null @@ -1,67 +0,0 @@ -import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack, useColorMode } from "@chakra-ui/react"; -import { useId, useState } from "react"; -import { colors } from "src/styles/Theme/colors"; - -// TODO: consolidate with FlaggableElement -interface LabelSliderGroupProps { - labelIDs: Array; - onChange: (sliderValues: number[]) => unknown; - isEditable?: boolean; -} - -export const LabelSliderGroup = ({ labelIDs, onChange, isEditable }: LabelSliderGroupProps) => { - const [sliderValues, setSliderValues] = useState(Array.from({ length: labelIDs.length }).map(() => 0)); - - return ( - - {labelIDs.map((labelId, idx) => ( - { - const newState = sliderValues.slice(); - newState[idx] = sliderValue; - onChange(newState); - setSliderValues(newState); - }} - isEditable={isEditable} - /> - ))} - - ); -}; - -function CheckboxSliderItem(props: { - labelId: string; - sliderValue: number; - sliderHandler: (newVal: number) => unknown; - isEditable: boolean; -}) { - const id = useId(); - const { colorMode } = useColorMode(); - - const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`; - - return ( - <> - - props.sliderHandler(val / 100)} - > - - - - - - - ); -} diff --git a/website/src/components/Survey/TaskControls.tsx b/website/src/components/Survey/TaskControls.tsx index 81c8df2a..76aaee8b 100644 --- a/website/src/components/Survey/TaskControls.tsx +++ b/website/src/components/Survey/TaskControls.tsx @@ -47,7 +47,7 @@ export const TaskControls = (props: TaskControlsProps) => { Submit @@ -59,7 +59,7 @@ export const TaskControls = (props: TaskControlsProps) => { Review diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 6cbead52..289bd892 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -12,6 +12,7 @@ export const CreateTask = ({ isEditable, isDisabled, onReplyChanged, + onValidityChanged, }: TaskSurveyProps<{ text: string }>) => { const cardColor = useColorModeValue("gray.50", "gray.800"); const titleColor = useColorModeValue("gray.800", "gray.300"); @@ -20,11 +21,12 @@ export const CreateTask = ({ const textChangeHandler = (event: React.ChangeEvent) => { const text = event.target.value; const isTextBlank = !text || /^\s*$/.test(text) ? true : false; + onReplyChanged({ text }); if (!isTextBlank) { - onReplyChanged({ content: { text }, state: "VALID" }); + onValidityChanged("VALID"); setInputText(text); } else { - onReplyChanged({ content: { text }, state: "INVALID" }); + onValidityChanged("INVALID"); setInputText(""); } }; diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx index 6ec92a96..4b43e35e 100644 --- a/website/src/components/Tasks/EvaluateTask.tsx +++ b/website/src/components/Tasks/EvaluateTask.tsx @@ -1,5 +1,5 @@ import { Box, useColorModeValue } from "@chakra-ui/react"; -import { useEffect } from "react"; +import { useEffect, useState } from "react"; import { MessageTable } from "src/components/Messages/MessageTable"; import { Sortable } from "src/components/Sortable/Sortable"; import { SurveyCard } from "src/components/Survey/SurveyCard"; @@ -12,8 +12,10 @@ export const EvaluateTask = ({ isEditable, isDisabled, onReplyChanged, + onValidityChanged, }: TaskSurveyProps<{ ranking: number[] }>) => { const cardColor = useColorModeValue("gray.50", "gray.800"); + const [ranking, setRanking] = useState(null); let messages = []; if (task.conversation) { @@ -22,13 +24,15 @@ export const EvaluateTask = ({ } useEffect(() => { - const ranking = (task.replies ?? task.prompts).map((_, idx) => idx); - onReplyChanged({ content: { ranking }, state: "DEFAULT" }); - }, [task, onReplyChanged]); - - const onRank = (newRanking: number[]) => { - onReplyChanged({ content: { ranking: newRanking }, state: "VALID" }); - }; + if (ranking === null) { + const defaultRanking = (task.replies ?? task.prompts).map((_, idx) => idx); + onReplyChanged({ ranking: defaultRanking }); + onValidityChanged("DEFAULT"); + } else { + onReplyChanged({ ranking }); + onValidityChanged("VALID"); + } + }, [task, ranking, onReplyChanged, onValidityChanged]); const sortables = task.replies ? "replies" : "prompts"; @@ -44,7 +48,7 @@ export const EvaluateTask = ({ items={task[sortables]} isDisabled={isDisabled} isEditable={isEditable} - onChange={onRank} + onChange={setRanking} className="my-8" /> diff --git a/website/src/components/Tasks/LabelTask/LabelTask.tsx b/website/src/components/Tasks/LabelTask/LabelTask.tsx index 7d6394df..ff8784b8 100644 --- a/website/src/components/Tasks/LabelTask/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask/LabelTask.tsx @@ -1,9 +1,8 @@ -import { Box, useColorModeValue } from "@chakra-ui/react"; +import { Box, Flex, Text, useColorModeValue } from "@chakra-ui/react"; import { useEffect, useState } from "react"; import { MessageView } from "src/components/Messages"; import { MessageTable } from "src/components/Messages/MessageTable"; -import { LabelRadioGroup } from "src/components/Survey/LabelRadioGroup"; -import { LabelSliderGroup } from "src/components/Survey/LabelSliderGroup"; +import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; import { TaskSurveyProps } from "src/components/Tasks/Task"; import { TaskHeader } from "src/components/Tasks/TaskHeader"; @@ -12,28 +11,18 @@ import { TaskType } from "src/types/Task"; export const LabelTask = ({ task, taskType, - onReplyChanged, isEditable, + onReplyChanged, + onValidityChanged, }: TaskSurveyProps<{ text: string; labels: Record; message_id: string }>) => { - const valid_labels = task.valid_labels; - const [sliderValues, setSliderValues] = useState(new Array(valid_labels.length).fill(0)); + const [sliderValues, setSliderValues] = useState(new Array(task.valid_labels.length).fill(null)); useEffect(() => { - onReplyChanged({ - content: { labels: {}, text: task.reply, message_id: task.message_id }, - state: "NOT_SUBMITTABLE", - }); - }, [task, onReplyChanged]); - - const onSliderChange = (values: number[]) => { - console.assert(valid_labels.length === sliderValues.length); - const labels = Object.fromEntries(valid_labels.map((label, i) => [label, sliderValues[i]])); - onReplyChanged({ - content: { labels, text: task.reply || task.prompt, message_id: task.message_id }, - state: "VALID", - }); - setSliderValues(values); - }; + console.assert(task.valid_labels.length === sliderValues.length); + const labels = Object.fromEntries(task.valid_labels.map((label, i) => [label, sliderValues[i]])); + onReplyChanged({ labels, text: task.reply || task.prompt, message_id: task.message_id }); + onValidityChanged(sliderValues.every((value) => value !== null) ? "VALID" : "INVALID"); + }, [task, sliderValues, onReplyChanged, onValidityChanged]); const cardColor = useColorModeValue("gray.50", "gray.800"); @@ -43,7 +32,7 @@ export const LabelTask = ({ <> {task.conversation ? ( - + )} - {task.mode === "simple" ? ( - - ) : ( - - )} + + The highlighted message: + + ); diff --git a/website/src/components/Tasks/Task/Task.tsx b/website/src/components/Tasks/Task/Task.tsx index 45fb83d0..b16711e6 100644 --- a/website/src/components/Tasks/Task/Task.tsx +++ b/website/src/components/Tasks/Task/Task.tsx @@ -6,8 +6,7 @@ import { LabelTask } from "src/components/Tasks/LabelTask"; import { TaskCategory, TaskInfo, TaskInfos } from "src/components/Tasks/TaskTypes"; import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning"; import { post } from "src/lib/api"; -import { TaskContent } from "src/types/Task"; -import { TaskReplyState } from "src/types/TaskReplyState"; +import { TaskContent, TaskReplyValidity } from "src/types/Task"; import useSWRMutation from "swr/mutation"; export type TaskStatus = "NOT_SUBMITTABLE" | "DEFAULT" | "VALID" | "REVIEW" | "SUBMITTED"; @@ -19,7 +18,8 @@ export interface TaskSurveyProps { taskType: TaskInfo; isEditable: boolean; isDisabled?: boolean; - onReplyChanged: (state: TaskReplyState) => void; + onReplyChanged: (content: T) => void; + onValidityChanged: (validity: TaskReplyValidity) => void; } export const Task = ({ frontendId, task, trigger, mutate }) => { @@ -44,20 +44,27 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { }); }; - const onReplyChanged = useRef((state: TaskReplyState) => { - if (taskStatus === "SUBMITTED") return; + const edit_mode = taskStatus === "NOT_SUBMITTABLE" || taskStatus === "DEFAULT" || taskStatus === "VALID"; + const submitted = taskStatus === "SUBMITTED"; - replyContent.current = state?.content; - if (state === null) { - if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE"); - } else if (state.state === "DEFAULT") { - if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT"); - } else if (state.state === "VALID") { - if (taskStatus !== "VALID") setTaskStatus("VALID"); - } else if (state.state === "INVALID") { - setTaskStatus("NOT_SUBMITTABLE"); + const onValidityChanged = (validity: TaskReplyValidity) => { + if (!edit_mode) return; + switch (validity) { + case "DEFAULT": + if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT"); + break; + case "VALID": + if (taskStatus !== "VALID") setTaskStatus("VALID"); + break; + case "INVALID": + if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE"); + break; } - }).current; + }; + + const onReplyChanged = (content: TaskContent) => { + replyContent.current = content; + }; const reviewResponse = () => { switch (taskStatus) { @@ -99,9 +106,6 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { } }; - const edit_mode = taskStatus === "NOT_SUBMITTABLE" || taskStatus === "DEFAULT" || taskStatus === "VALID"; - const submitted = taskStatus === "SUBMITTED"; - function taskTypeComponent() { switch (taskType.category) { case TaskCategory.Create: @@ -113,6 +117,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { isEditable={edit_mode} isDisabled={submitted} onReplyChanged={onReplyChanged} + onValidityChanged={onValidityChanged} /> ); case TaskCategory.Evaluate: @@ -124,6 +129,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { isEditable={edit_mode} isDisabled={submitted} onReplyChanged={onReplyChanged} + onValidityChanged={onValidityChanged} /> ); case TaskCategory.Label: @@ -135,6 +141,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { isEditable={edit_mode} isDisabled={submitted} onReplyChanged={onReplyChanged} + onValidityChanged={onValidityChanged} /> ); } diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index d10159d9..cfa5982a 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -162,7 +162,7 @@ export const TaskInfos: TaskInfo[] = [ category: TaskCategory.Label, pathname: "/label/label_prompter_reply", help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting", - overview: "Read the following conversation and then answer the question about the last prompt in the discussion.", + overview: "Read the following conversation and then answer the question about the last reply in the discussion.", type: "label_prompter_reply", mode: "simple", update_type: "text_labels", @@ -173,7 +173,7 @@ export const TaskInfos: TaskInfo[] = [ category: TaskCategory.Label, pathname: "/label/label_assistant_reply", help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting", - overview: "Read the following conversation and then answer the question about the last prompt in the discussion.", + overview: "Read the following conversation and then answer the question about the last reply in the discussion.", type: "label_assistant_reply", mode: "simple", update_type: "text_labels", diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index 8e5ada44..12e37db0 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -35,4 +35,6 @@ export interface TaskResponse { task: Task; } +export type TaskReplyValidity = "DEFAULT" | "VALID" | "INVALID"; + export type AvailableTasks = { [taskType in TaskType]: number }; diff --git a/website/src/types/TaskReplyState.ts b/website/src/types/TaskReplyState.ts deleted file mode 100644 index 100aed11..00000000 --- a/website/src/types/TaskReplyState.ts +++ /dev/null @@ -1,22 +0,0 @@ -export interface TaskReplyNotSubmittable { - content: T; - state: "NOT_SUBMITTABLE"; -} -export interface TaskReplyValid { - content: T; - state: "VALID"; -} -export interface TaskReplyDefault { - content: T; - state: "DEFAULT"; -} -export interface TaskReplyInValid { - content: T; - state: "INVALID"; -} - -export type TaskReplyState = - | TaskReplyNotSubmittable - | TaskReplyValid - | TaskReplyDefault - | TaskReplyInValid; From 25cf9eb95361860e67cdbc50434befc8b659b4b4 Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Mon, 23 Jan 2023 21:23:15 +1100 Subject: [PATCH 038/111] website: Switch to radio buttons for likert style labeling --- .../e2e/tasks/label_assistant_reply.cy.ts | 2 +- .../e2e/tasks/label_initial_prompt.cy.ts | 2 +- .../e2e/tasks/label_prompter_reply.cy.ts | 2 +- website/cypress/e2e/tasks/random.cy.ts | 5 +- .../src/components/Buttons/LikertButtons.tsx | 41 ++++------ .../src/components/Survey/LabelInputGroup.tsx | 77 ++++++++++--------- 6 files changed, 61 insertions(+), 68 deletions(-) diff --git a/website/cypress/e2e/tasks/label_assistant_reply.cy.ts b/website/cypress/e2e/tasks/label_assistant_reply.cy.ts index 3018f8f5..422db37c 100644 --- a/website/cypress/e2e/tasks/label_assistant_reply.cy.ts +++ b/website/cypress/e2e/tasks/label_assistant_reply.cy.ts @@ -13,7 +13,7 @@ describe("labeling assistant replies", () => { cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option - cy.wrap(label).find('[aria-roledescription="radio"]').eq(3).click(); + cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); }); cy.get('[data-cy="review"]').click(); diff --git a/website/cypress/e2e/tasks/label_initial_prompt.cy.ts b/website/cypress/e2e/tasks/label_initial_prompt.cy.ts index 7f66ebaf..be1cf9bb 100644 --- a/website/cypress/e2e/tasks/label_initial_prompt.cy.ts +++ b/website/cypress/e2e/tasks/label_initial_prompt.cy.ts @@ -13,7 +13,7 @@ describe("labeling initial prompts", () => { cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option - cy.wrap(label).find('[aria-roledescription="radio"]').eq(3).click(); + cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); }); cy.get('[data-cy="review"]').click(); diff --git a/website/cypress/e2e/tasks/label_prompter_reply.cy.ts b/website/cypress/e2e/tasks/label_prompter_reply.cy.ts index dbb2fb17..a3c06cb3 100644 --- a/website/cypress/e2e/tasks/label_prompter_reply.cy.ts +++ b/website/cypress/e2e/tasks/label_prompter_reply.cy.ts @@ -13,7 +13,7 @@ describe("labeling prompter replies", () => { cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option - cy.wrap(label).find('[aria-roledescription="radio"]').eq(3).click(); + cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); }); cy.get('[data-cy="review"]').click(); diff --git a/website/cypress/e2e/tasks/random.cy.ts b/website/cypress/e2e/tasks/random.cy.ts index aad2d23c..7e14dd1d 100644 --- a/website/cypress/e2e/tasks/random.cy.ts +++ b/website/cypress/e2e/tasks/random.cy.ts @@ -46,10 +46,7 @@ describe("handles random tasks", () => { case "label-task": { cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option - cy.wrap(label) - .find('[aria-roledescription="radio"]') - .eq(3) - .click(); + cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); }); cy.get('[data-cy="review"]').click(); diff --git a/website/src/components/Buttons/LikertButtons.tsx b/website/src/components/Buttons/LikertButtons.tsx index 6b1ba319..150452dc 100644 --- a/website/src/components/Buttons/LikertButtons.tsx +++ b/website/src/components/Buttons/LikertButtons.tsx @@ -1,38 +1,31 @@ -import { Button, SimpleGrid } from "@chakra-ui/react"; -import { PropsWithChildren, ReactNode } from "react"; +import { Radio, RadioGroup } from "@chakra-ui/react"; +import { PropsWithChildren } from "react"; export const LikertButtons = ({ isDisabled, - options, - value, + count, onChange, "data-cy": dataCy, }: PropsWithChildren<{ isDisabled: boolean; - options: ReactNode[]; - value: number; + count: number; onChange: (value: number) => void; "data-cy"?: string; }>) => { + const valueMap = Object.fromEntries(Array.from({ length: count }, (_, idx) => [`${idx}`, idx / (count - 1)])); + return ( - - {options.map((option, idx) => { - const indexValue = idx / (options.length - 1); - return ( - - ); + { + onChange(valueMap[value]); + }} + style={{ display: "flex", justifyContent: "space-between" }} + > + {Object.keys(valueMap).map((value) => { + return ; })} - + ); }; diff --git a/website/src/components/Survey/LabelInputGroup.tsx b/website/src/components/Survey/LabelInputGroup.tsx index 08d8a628..58a78260 100644 --- a/website/src/components/Survey/LabelInputGroup.tsx +++ b/website/src/components/Survey/LabelInputGroup.tsx @@ -1,7 +1,6 @@ -import { Box, Flex, Grid, Spacer, Text, useColorModeValue, VStack } from "@chakra-ui/react"; +import { Box, Grid, GridItem, Text, useColorModeValue } from "@chakra-ui/react"; import React from "react"; import { useState } from "react"; -import { TbChevronLeft, TbChevronRight, TbChevronsLeft, TbChevronsRight } from "react-icons/tb"; import { LikertButtons } from "src/components/Buttons/LikertButtons"; import { Explain } from "src/components/Explain"; @@ -25,7 +24,7 @@ const label_messages: { [label: string]: LabelInfo; } = { spam: { - zeroText: "Suitable for Training", + zeroText: "Not Spam", zeroDescription: ["Suitable for training Open Assistant."], oneText: "Spam", oneDescription: [ @@ -53,14 +52,14 @@ const label_messages: { violence: { zeroText: "Harmless", zeroDescription: [], - oneText: "Encourages Violence", + oneText: "Violent", oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"], inverted: true, }, excessive_harm: { zeroText: "Safe", zeroDescription: [], - oneText: "Causes Harm", + oneText: "Harmful", oneDescription: [ "Content likely to cause excessive harm not justifiable in the context", "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", @@ -70,14 +69,14 @@ const label_messages: { sexual_content: { zeroText: "Non Sexual", zeroDescription: [], - oneText: "Sexual Content", + oneText: "Sexual", oneDescription: ["Contains sexual content"], inverted: true, }, toxicity: { - zeroText: "Non Toxic", + zeroText: "Polite", zeroDescription: [], - oneText: "Rude / Toxic", + oneText: "Rude", oneDescription: ["Contains rude, abusive, profane or insulting content"], inverted: true, }, @@ -98,7 +97,7 @@ const label_messages: { humor: { zeroText: "Serious", zeroDescription: [], - oneText: "Humorous / Sarcastic", + oneText: "Humorous", oneDescription: ["Contains humorous content including sarcasm"], inverted: false, }, @@ -115,7 +114,7 @@ const label_messages: { threat: { zeroText: "Safe", zeroDescription: [], - oneText: "Contains Threat", + oneText: "Threatening", oneDescription: ["Contains a threat against a person or persons"], inverted: true, }, @@ -159,34 +158,38 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label if (inverted) [textA, textB, descriptionA, descriptionB] = [textB, textA, descriptionB, descriptionA]; return ( - - - - {textA} + + + + {textA} {descriptionA.length > 0 ? : null} - - {textB} - {descriptionB.length > 0 ? : null} - - , - , - "", - , - , - ]} - data-cy="label-options" - value={labelValues[idx] === null ? null : inverted ? 1 - labelValues[idx] : labelValues[idx]} - onChange={(value) => { - const newState = labelValues.slice(); - newState[idx] = value === null ? null : inverted ? 1 - value : value; - onChange(newState); - setLabelValues(newState); - }} - /> - +
+ + { + const newState = labelValues.slice(); + newState[idx] = value === null ? null : inverted ? 1 - value : value; + onChange(newState); + setLabelValues(newState); + }} + /> + + + + {textB} + {descriptionB.length > 0 ? : null} + + + ); })} From a6d23821beb3f2c4f1c97aa3ee43f7f7d21a4209 Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Mon, 23 Jan 2023 22:49:30 +1100 Subject: [PATCH 039/111] website: Highlight target message --- website/src/components/Messages/MessageTable.tsx | 12 +++++++++--- .../src/components/Messages/MessageTableEntry.tsx | 8 ++++++-- website/src/components/Tasks/CreateTask.tsx | 2 +- website/src/components/Tasks/EvaluateTask.tsx | 2 +- website/src/components/Tasks/LabelTask/LabelTask.tsx | 1 + website/styles/Theme/colors.tsx | 2 ++ 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index ed98752c..acf92e05 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -5,13 +5,19 @@ import { Message } from "src/types/Conversation"; interface MessageTableProps { messages: Message[]; enableLink?: boolean; + highlightLastMessage?: boolean; } -export function MessageTable({ messages, enableLink }: MessageTableProps) { +export function MessageTable({ messages, enableLink, highlightLastMessage }: MessageTableProps) { return ( - {messages.map((item) => ( - + {messages.map((item, idx) => ( + ))} ); diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 1205991e..77202c44 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -1,14 +1,15 @@ -import { Avatar, Box, HStack, LinkBox, useBreakpoint, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; +import { Avatar, Box, HStack, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; -import Link from "next/link"; import { useRouter } from "next/router"; import { useCallback, useMemo } from "react"; import { FlaggableElement } from "src/components/FlaggableElement"; import { Message } from "src/types/Conversation"; +import { colors } from "styles/Theme/colors"; interface MessageTableEntryProps { item: Message; enabled?: boolean; + highlight?: boolean; } export function MessageTableEntry(props: MessageTableEntryProps) { @@ -37,6 +38,7 @@ export function MessageTableEntry(props: MessageTableEntryProps) { ), [borderColor, inlineAvatar, item.is_assistant] ); + const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight); return ( @@ -48,6 +50,8 @@ export function MessageTableEntry(props: MessageTableEntryProps) { p="4" borderRadius="md" bg={item.is_assistant ? backgroundColor : backgroundColor2} + outline={props.highlight && "2px solid black"} + outlineColor={highlightColor} onClick={props.enabled && goToMessage} _hover={props.enabled && { cursor: "pointer", opacity: 0.9 }} whiteSpace="pre-wrap" diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 6cbead52..d52c6bb8 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -36,7 +36,7 @@ export const CreateTask = ({ {task.conversation ? ( - + ) : null} diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx index 6ec92a96..497c57e3 100644 --- a/website/src/components/Tasks/EvaluateTask.tsx +++ b/website/src/components/Tasks/EvaluateTask.tsx @@ -38,7 +38,7 @@ export const EvaluateTask = ({ - + ) : ( diff --git a/website/styles/Theme/colors.tsx b/website/styles/Theme/colors.tsx index acadfa2b..7f82ebce 100644 --- a/website/styles/Theme/colors.tsx +++ b/website/styles/Theme/colors.tsx @@ -4,11 +4,13 @@ export const colors = { btn: "gray.50", div: "white", text: "black", + highlight: "blue.400", }, dark: { bg: "gray.900", btn: "gray.600", div: "gray.700", text: "gray.200", + highlight: "blue.500", }, }; From 957bd25793500f4232ab981ce64bca2365e8d4b3 Mon Sep 17 00:00:00 2001 From: notmd Date: Mon, 23 Jan 2023 19:30:22 +0700 Subject: [PATCH 040/111] fix wrong import --- website/src/pages/dashboard.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index 2c5bea5e..32774f4a 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -1,7 +1,7 @@ import { Flex } from "@chakra-ui/react"; import Head from "next/head"; +import { useTranslation } from "next-i18next"; import { useEffect, useMemo, useState } from "react"; -import { useTranslation } from "react-i18next"; import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard"; import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; From d1d185edb6a4e3b3787f09a8fe0e6ea47422c80d Mon Sep 17 00:00:00 2001 From: Yada Pruksachatkun Date: Mon, 23 Jan 2023 05:24:43 -0800 Subject: [PATCH 041/111] Adding MT Sample clinical note dataset (#804) * Adding clinical note dataset * Fix flake8 issues * Fix prepare.py for straggling commas, replace assistant with Rosey in prompt Co-authored-by: Yada P --- .../datasets/mt_note_generation/README.md | 101 ++++++++++++++ .../datasets/mt_note_generation/__init__.py | 0 .../datasets/mt_note_generation/hub.py | 21 +++ .../mt_note_generation/mt_note_generation.py | 123 ++++++++++++++++++ .../datasets/mt_note_generation/prepare.py | 84 ++++++++++++ train_toxicity_model.py | 0 6 files changed, 329 insertions(+) create mode 100644 openassistant/datasets/mt_note_generation/README.md create mode 100644 openassistant/datasets/mt_note_generation/__init__.py create mode 100644 openassistant/datasets/mt_note_generation/hub.py create mode 100644 openassistant/datasets/mt_note_generation/mt_note_generation.py create mode 100644 openassistant/datasets/mt_note_generation/prepare.py create mode 100644 train_toxicity_model.py diff --git a/openassistant/datasets/mt_note_generation/README.md b/openassistant/datasets/mt_note_generation/README.md new file mode 100644 index 00000000..74754417 --- /dev/null +++ b/openassistant/datasets/mt_note_generation/README.md @@ -0,0 +1,101 @@ +--- +annotations_creators: + - no-annotation +language: + - en +language_creators: + - machine-generated +license: + - mit +multilinguality: + - monolingual +pretty_name: "MT Note Generation" +size_categories: + - <500 +source_datasets: + - mt_samples +tags: + - open-assistant +task_categories: + - conversational +task_ids: + - dialogue-generation +dataset_info: + features: + - name: conversation + dtype: string + splits: + - name: train + num_bytes: 1060800 + num_examples: 270 + - name: test + num_bytes: 339323 + num_examples: 90 + - name: validation + num_bytes: 337018 + num_examples: 90 +--- + +# Dataset Card for MT Samples Clinical Note Generation + +## Dataset Description + +- **Repository:** + [Code](https://github.com/LAION-AI/Open-Assistant/openassistant/datasets/mt_note_generation/) + +## Dataset Summary + +MT Note Generation is a set of synthetic dialogues between Assistant and User +where the user asks the assistant to generate a clinical note for a patient +persona. + +This data was created by synthesizing the dialogues in +[MT Samples](https://mtsamples.com/) for clinical note generation. + +Example: + +``` +User: Write a clinical note about a patient with the following chief complaint: Stomach pain for 2 weeks. + +ASSISTANT: + +CHIEF COMPLAINT: Stomach pain for 2 weeks. HISTORY OF PRESENT ILLNESS: The patient is a 45yo Mexican man without significant past medical history who presents to the emergency room with complaints of mid-epigastric and right upper quadrant abdominal pain for the last 14 days. The pain was initially crampy and burning in character and was relieved with food intake. He also reports that it initially was associated with a sour taste in his mouth. He went to his primary care physician who prescribed cimetidine 400mg qhs x 5 days; however, this did not relieve his symptoms. In fact, the pain has worsened such that the pain now radiates to the back but is waxing and waning in duration. It is relieved with standing and ambulation and exacerbated when lying in a supine position. He reports a decrease in appetite associated with a 4 lb. wt loss over the last 2 wks. He does have nausea with only one episode of non-bilious, non-bloody emesis on day of admission. He reports a 2 wk history of subjective fever and diaphoresis. He denies any diarrhea, constipation, dysuria, melena, or hematochezia. His last bowel movement was during the morning of admission and was normal. He denies any travel in the last 9 years and sick contacts.PAST MEDICAL HISTORY: Right inguinal groin cyst removal 15 years ago. Unknown etiology. No recurrence. + +PAST SURGICAL HISTORY: Left femoral neck fracture with prosthesis secondary to a fall 4 years ago. + +FAMILY HISTORY: Mother with diabetes. No history of liver disease. No malignancies. + +SOCIAL HISTORY: The patient was born in central Mexico but moved to the United States 9 years ago. He is on disability due to his prior femoral fracture. He denies any tobacco or illicit drug use. He only drinks alcohol socially, no more than 1 drink every few weeks. He is married and has 3 healthy children. He denies any tattoos or risky sexual behavior. + +ALLERGIES: NKDA. + +MEDICATIONS: Tylenol prn (1-2 tabs every other day for the last 2 wks), Cimetidine 400mg po qhs x 5 days. + +REVIEW OF SYSTEMS: No headache, vision changes. No shortness of breath. No chest pain or palpitations. + +PHYSICAL EXAMINATION: Vitals: T 100.9-102.7 BP 136/86 Pulse 117 RR 12 98% sat on room air,Gen: Well-developed, well-nourished, no apparent distress.HEENT: Pupils equal, round and reactive to light. Anicteric. Oropharynx clear and moist.Neck: Supple. No lymphadenopathy or carotid bruits. No thyromegaly or masses.CHEST: Clear to auscultation bilaterally.CV: Tachycardic but regular rhythm, normal S1/S2, no murmurs/rubs/gallops.Abd: Soft, active bowel sounds. Tender in the epigastrium and right upper quadrant with palpation associated with slight guarding. No rebound tenderness. No hepatomegaly. No splenomegaly.Rectal: Stool was brown and guaiac negative.Ext: No cyanosis/clubbing/edema.Neurological: He was alert and oriented x3. CN II-XII intact. Normal 2+ DTRs. No focal neurological deficit.Skin: No jaundice. No skin rashes or lesions. + +IMAGING DATA:CT Abdomen with contrast ( 11/29/03 ): There is a 6x6 cm multilobular hypodense mass seen at the level of the hepatic hilum and caudate lobe which is resulting in mass effect with dilatation of the intrahepatic radicals of the left lobe of the liver. The rest of the liver parenchyma is homogeneous. The gallbladder, pancreas, spleen, adrenal glands and kidneys are within normal limits. The retroperitoneal vascular structures are within normal limits. There is no evidence of lymphadenopathy, free fluid or fluid collections.HOSPITAL COURSE: The patient was admitted to the hospital for further evaluation. A diagnostic procedure was performed. +``` + +## Usage + +The dataset contains one configuration, `dialogue_modeling`, which has a single +text `conversation` feature. + +## Source data + +The script modifies data from mtsamples.csv which is hosted in Kaggle: +https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions + +## Citation + +Please cite our work if you find the resources in this repository useful: + +``` +@article{pruks2023mtsamplesnotegen, + author = {Yada Pruksachatkun}, + title = {MT Samples Note Generation}, + year = {2023} +} +``` diff --git a/openassistant/datasets/mt_note_generation/__init__.py b/openassistant/datasets/mt_note_generation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/openassistant/datasets/mt_note_generation/hub.py b/openassistant/datasets/mt_note_generation/hub.py new file mode 100644 index 00000000..296ecee4 --- /dev/null +++ b/openassistant/datasets/mt_note_generation/hub.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +import datasets + + +@dataclass +class OpenAssistantConfig(datasets.BuilderConfig): + """BuilderConfig for OpenAssistant datasets.""" + + name: str = None + version: datasets.Version = None + description: str = None + schema: str = None + subset_id: str = None + + +features = datasets.Features( + { + "conversation": datasets.Value("string"), + } +) diff --git a/openassistant/datasets/mt_note_generation/mt_note_generation.py b/openassistant/datasets/mt_note_generation/mt_note_generation.py new file mode 100644 index 00000000..d78b91b8 --- /dev/null +++ b/openassistant/datasets/mt_note_generation/mt_note_generation.py @@ -0,0 +1,123 @@ +# Copyright 2023 The OpenAssistant Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MT Note Generation is a set of synthetic dialogues between Assistant and +User where the user asks the assistant to generate a clinical note for a patient persona. +""" + +import json +from typing import Dict, List, Tuple + +import datasets + +from .hub import OpenAssistantConfig, features + +_CITATION = """\ + @misc{transcribed medical transcription sample reports and examples, title={Welcome to MTSamples}, + url={https://mtsamples.com/}, + journal={Transcribed Medical Transcription Sample Reports and Examples}} +""" + +_DATASETNAME = "mt_note_generation" +_DISPLAYNAME = "MT Samples Note Generation" + +_DESCRIPTION = """\ +A dataset of instructions for generating clinical notes from MT samples. +""" + +_HOMEPAGE = "" + +_LICENSE = "mit" + +_URLS = { + _DATASETNAME: { + "train": "./data/mt_note_generation_train.jsonl", + "test": "./data/mt_note_generation_test.jsonl", + "validation": "./data/mt_note_generation_validation.jsonl", + } +} + +_SUPPORTED_TASKS = ["dialogue-modeling"] + +_VERSION = "1.0.0" + + +class MTNoteGenerationDataset(datasets.GeneratorBasedBuilder): + """A set of dialogues synthesized from the MT Samples dataset.""" + + VERSION = datasets.Version(_VERSION) + + BUILDER_CONFIGS = [ + OpenAssistantConfig( + name=f"{_DATASETNAME}_dialogue_modeling", + version=VERSION, + description=f"OpenAssistant dataset config for {_DATASETNAME}", + schema="dialogue_modeling", + subset_id=_DATASETNAME, + ) + ] + + DEFAULT_CONFIG_NAME = f"{_DATASETNAME}_dialogue_modeling" + + def _info(self) -> datasets.DatasetInfo: + + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager) -> List[datasets.SplitGenerator]: + + urls = _URLS[_DATASETNAME] + data_dir = dl_manager.download_and_extract(urls) + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # Whatever you put in gen_kwargs will be passed to _generate_examples + gen_kwargs={ + "filepath": data_dir, + "split": "train", + }, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "filepath": data_dir, + "split": "test", + }, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={ + "filepath": data_dir, + "split": "validation", + }, + ), + ] + + def _generate_examples(self, filepath, split: str) -> Tuple[int, Dict]: + """Yields examples as (key, example) tuples.""" + if self.config.schema == "dialogue_modeling": + key = 0 + with open(filepath[split], "r", encoding="utf8") as data: + while True: + line = data.readline() + if not line: + return + yield key, json.loads(line) + key += 1 diff --git a/openassistant/datasets/mt_note_generation/prepare.py b/openassistant/datasets/mt_note_generation/prepare.py new file mode 100644 index 00000000..7f0a9146 --- /dev/null +++ b/openassistant/datasets/mt_note_generation/prepare.py @@ -0,0 +1,84 @@ +import json +import math +import os +import random +import re +import sys +from string import punctuation + +import kaggle +import pandas as pd + +CLINICAL_NOTE_GENERATION_TEMPLATE = """User: Write a clinical note about a patient with the following {section}: {section_information}. +Rosey: {note}""" + + +def preprocess(mt_dataset): + def filter_for_notes(row): + normalized_transcript = row["transcription"].lower() + if "chief complaint:" in normalized_transcript: + return True + return False + + mt_dataset = mt_dataset.dropna(subset=["description", "transcription"]) + mt_note_subset = mt_dataset[mt_dataset.apply(filter_for_notes, axis=1)] + return mt_note_subset + + +def is_chief_complaint(section): + return "chief complaint" in section.lower() + + +def get_conversations(dataset): + def normalize_transcript(x): + x = re.sub(r"\.+", ".", x) + x = re.sub(r"\,+", ",", x) + x = re.sub(r":\s+", ": ", x) + x = re.sub(r"\.\s+", ". ", x) + x = re.sub(r":(\s)*\,+", ": ", x) + x = re.sub(r"\.\,+", ". ", x) + return x + + conversations = [] + for idx in range(len(dataset)): + transcript = normalize_transcript(dataset.iloc[idx]["transcription"]) + sections = re.findall(r"\b[A-Z]+(?: [A-Z]+)*:", transcript) + if len(sections) >= 2: + note_prompt = transcript.split(sections[0])[1].split(sections[1])[0] + else: + continue + section_name = sections[0].lower().strip(punctuation) + if len(note_prompt.split(" ")) > 30 and is_chief_complaint(section_name): + # There are some chief complaints that seem to be HPI + section_name = "history of present illness" + conversations.append( + CLINICAL_NOTE_GENERATION_TEMPLATE.format( + section=section_name, section_information=note_prompt, note=transcript + ) + ) + return conversations + + +def main(output_dir: str = "data"): + """Download and prepare the dataset for use.""" + os.makedirs(output_dir, exist_ok=True) + kaggle.api.dataset_download_files("tboyle10/medicaltranscriptions", "data", unzip=True) + mt_samples = preprocess(pd.read_csv("mtsamples.csv")) + conversations = get_conversations(mt_samples) + random.shuffle(conversations) + train_limit = math.ceil(len(conversations) * 0.6) + dev_limit = math.ceil(len(conversations) * 0.8) + train, validation, test = ( + conversations[:train_limit], + conversations[train_limit:dev_limit], + conversations[dev_limit:], + ) + splits = {"train": train, "validation": validation, "test": test} + for split in ["train", "validation", "test"]: + with open(f"{output_dir}/mt_note_generation_{split}.jsonl", "w") as f: + for conversation in splits[split]: + f.write(f"{json.dumps({'conversation': conversation})}\n") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/train_toxicity_model.py b/train_toxicity_model.py new file mode 100644 index 00000000..e69de29b From 40dbe55a9013ae140d0e5b46a1c08c64d8e082b1 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Mon, 23 Jan 2023 13:30:29 +0000 Subject: [PATCH 042/111] Categorise notebooks as documentation to improve project language breakdown (#860) --- .gitattributes | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitattributes b/.gitattributes index 6313b56c..7bc47e3f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ * text=auto eol=lf +*.ipynb linguist-documentation From cd2e883e9d3eeabf7056eefc91cfaced83a6636d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Mon, 23 Jan 2023 14:32:26 +0100 Subject: [PATCH 043/111] remove empty file in root folder --- train_toxicity_model.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 train_toxicity_model.py diff --git a/train_toxicity_model.py b/train_toxicity_model.py deleted file mode 100644 index e69de29b..00000000 From 037ec3c96e6c739f8e7a35ed8ffc44ba9a067285 Mon Sep 17 00:00:00 2001 From: notmd Date: Mon, 23 Jan 2023 21:10:56 +0700 Subject: [PATCH 044/111] wip --- website/public/locales/en/leaderboard.json | 3 +- website/src/components/DataTable.tsx | 80 +++++++------- .../LeaderboardGridCell.tsx | 101 +++++++----------- website/src/pages/leaderboard.tsx | 50 +++++---- 4 files changed, 105 insertions(+), 129 deletions(-) diff --git a/website/public/locales/en/leaderboard.json b/website/public/locales/en/leaderboard.json index c2dd0832..d1d7ed92 100644 --- a/website/public/locales/en/leaderboard.json +++ b/website/public/locales/en/leaderboard.json @@ -7,5 +7,6 @@ "rank": "Rank", "score": "Score", "user": "User", - "weekly": "Weekly" + "weekly": "Weekly", + "prompt_tasks": "Prompt Tasks" } diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx index 466393eb..d29a2813 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable.tsx @@ -1,8 +1,6 @@ import { Box, Button, - Card, - CardBody, Flex, FormControl, FormLabel, @@ -49,6 +47,7 @@ export type DataTableProps = { onFilterChange?: (items: FilterItem[]) => void; disableNext?: boolean; disablePrevious?: boolean; + disablePagination?: boolean; }; export const DataTable = ({ @@ -61,6 +60,7 @@ export const DataTable = ({ onFilterChange, disableNext, disablePrevious, + disablePagination, }: DataTableProps) => { const { getHeaderGroups, getRowModel } = useReactTable({ data, @@ -79,8 +79,8 @@ export const DataTable = ({ onFilterChange(newValues); }; return ( - - + <> + {disablePagination && ( - - - {caption} - - {getHeaderGroups().map((headerGroup) => ( - - {headerGroup.headers.map((header) => ( - - ))} - - ))} - - - {getRowModel().rows.map((row) => ( - - {row.getVisibleCells().map((cell) => ( - - ))} - - ))} - -
- - {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} - {(header.column.columnDef as DataTableColumnDef).filterable && ( - value.id === header.id)?.value ?? ""} - onChange={(value) => handleFilterChange({ id: header.id, value })} - label={flexRender(header.column.columnDef.header, header.getContext())} - > - )} - -
{flexRender(cell.column.columnDef.cell, cell.getContext())}
-
-
-
+ )} + + + {caption} + + {getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + ))} + + ))} + + + {getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + ))} + + ))} + +
+ + {header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())} + {(header.column.columnDef as DataTableColumnDef).filterable && ( + value.id === header.id)?.value ?? ""} + onChange={(value) => handleFilterChange({ id: header.id, value })} + label={flexRender(header.column.columnDef.header, header.getContext())} + > + )} + +
{flexRender(cell.column.columnDef.cell, cell.getContext())}
+
+ ); }; diff --git a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx index 7886784a..a65b7c2a 100644 --- a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx +++ b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx @@ -1,90 +1,61 @@ -import { Table, TableContainer, Tbody, Td, Text, Th, Thead, Tr, useColorModeValue } from "@chakra-ui/react"; +import { CircularProgress } from "@chakra-ui/react"; +import { createColumnHelper } from "@tanstack/react-table"; import { useTranslation } from "next-i18next"; import React, { useMemo } from "react"; -import { useTable } from "react-table"; import { get } from "src/lib/api"; -import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; +import { LeaderboardEntity, LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import useSWRImmutable from "swr/immutable"; -const getColumns = (t) => [ - { - Header: t("rank"), - accessor: "rank", - style: { width: "90px" }, - }, - { - Header: t("score"), - accessor: "leader_score", - style: { width: "90px" }, - }, - { - Header: t("user"), - accessor: "display_name", - }, -]; +import { DataTable } from "../DataTable"; + +const columnHelper = createColumnHelper(); /** * Presents a grid of leaderboard entries with more detailed information. */ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) => { - const { t } = useTranslation(["leaderboard", "common"]); - const { data: reply } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}`, get, { + const { t } = useTranslation("leaderboard"); + const { + data: reply, + isLoading, + error, + } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}`, get, { revalidateOnMount: true, }); - const columns = useMemo(() => getColumns(t), [t]); - - const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({ - columns, - data: reply?.leaderboard ?? [], - }); - - const backgroundColor = useColorModeValue("white", "gray.800"); + const columns = useMemo( + () => [ + columnHelper.accessor("rank", { + header: t("rank"), + }), + columnHelper.accessor("display_name", { + header: t("user"), + }), + columnHelper.accessor("leader_score", { + header: t("score"), + }), + columnHelper.accessor("prompts", { + header: t("pro"), + }), + ], + [t] + ); const lastUpdated = useMemo(() => { const val = new Date(reply?.last_updated); return t("last_updated_at", { val, formatParams: { val: { dateStyle: "full", timeStyle: "short" } } }); }, [t, reply?.last_updated]); + console.log(reply, isLoading); - if (!reply) { - return null; + if (isLoading) { + return ; } - return ( - - - - {headerGroups.map((headerGroup, idx) => ( - - {headerGroup.headers.map((column) => ( - - ))} - - ))} - + if (error) { + return Unable to load leaderboard; + } - - {rows.map((row) => { - prepareRow(row); - return ( - - {row.cells.map((cell, idx) => { - return ( - - ); - })} - - ); - })} - -
- {column.render("Header")} -
- {cell.render("Cell")} -
- {lastUpdated} -
- ); + return ; }; export { LeaderboardGridCell }; diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index e413366f..052c94ad 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -1,4 +1,4 @@ -import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-ui/react"; +import { Box, Card, CardBody, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-ui/react"; import Head from "next/head"; import { useTranslation } from "next-i18next"; import { getDashboardLayout } from "src/components/Layout"; @@ -18,29 +18,33 @@ const Leaderboard = () => { {t("leaderboard")} - - - {t("daily")} - {t("weekly")} - {t("monthly")} - {t("overall")} - + + + + + {t("daily")} + {t("weekly")} + {t("monthly")} + {t("overall")} + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + ); From b7056eccd18466aa85509a1973cd9b59110e75b4 Mon Sep 17 00:00:00 2001 From: notmd Date: Tue, 24 Jan 2023 02:04:00 +0700 Subject: [PATCH 045/111] Show more stats in leaderboard table --- website/package-lock.json | 19 ----------- website/package.json | 1 - website/public/locales/en/leaderboard.json | 4 ++- ...erboardTable.tsx => LeaderboardWidget.tsx} | 22 +++++-------- website/src/components/Dashboard/index.ts | 2 +- .../components/LeaderboardGridCell/index.tsx | 1 - .../LeaderboardTable.tsx} | 17 ++++++---- .../src/components/LeaderboardTable/index.tsx | 1 + website/src/components/UserTable.tsx | 32 ++++++++++--------- website/src/lib/oasst_api_client.ts | 10 ++++-- website/src/pages/api/leaderboard.ts | 2 +- website/src/pages/dashboard.tsx | 4 +-- website/src/pages/leaderboard.tsx | 10 +++--- 13 files changed, 56 insertions(+), 69 deletions(-) rename website/src/components/Dashboard/{LeaderboardTable.tsx => LeaderboardWidget.tsx} (50%) delete mode 100644 website/src/components/LeaderboardGridCell/index.tsx rename website/src/components/{LeaderboardGridCell/LeaderboardGridCell.tsx => LeaderboardTable/LeaderboardTable.tsx} (74%) create mode 100644 website/src/components/LeaderboardTable/index.tsx diff --git a/website/package-lock.json b/website/package-lock.json index 71177bf5..348d9fad 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -46,7 +46,6 @@ "react-feature-flags": "^1.0.0", "react-hook-form": "^7.42.1", "react-i18next": "^12.1.4", - "react-table": "^7.8.0", "sharp": "^0.31.3", "swr": "^2.0.0", "tailwindcss": "^3.2.4", @@ -32756,18 +32755,6 @@ } } }, - "node_modules/react-table": { - "version": "7.8.0", - "resolved": "https://registry.npmjs.org/react-table/-/react-table-7.8.0.tgz", - "integrity": "sha512-hNaz4ygkZO4bESeFfnfOft73iBUj8K5oKi1EcSHPAibEydfsX2MyU6Z8KCr3mv3C9Kqqh71U+DhZkFvibbnPbA==", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - }, - "peerDependencies": { - "react": "^16.8.3 || ^17.0.0-0 || ^18.0.0" - } - }, "node_modules/read-cache": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", @@ -62197,12 +62184,6 @@ "tslib": "^2.0.0" } }, - "react-table": { - "version": "7.8.0", - "resolved": "https://registry.npmjs.org/react-table/-/react-table-7.8.0.tgz", - "integrity": "sha512-hNaz4ygkZO4bESeFfnfOft73iBUj8K5oKi1EcSHPAibEydfsX2MyU6Z8KCr3mv3C9Kqqh71U+DhZkFvibbnPbA==", - "requires": {} - }, "read-cache": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", diff --git a/website/package.json b/website/package.json index f2499920..40d279fe 100644 --- a/website/package.json +++ b/website/package.json @@ -63,7 +63,6 @@ "react-feature-flags": "^1.0.0", "react-hook-form": "^7.42.1", "react-i18next": "^12.1.4", - "react-table": "^7.8.0", "sharp": "^0.31.3", "swr": "^2.0.0", "tailwindcss": "^3.2.4", diff --git a/website/public/locales/en/leaderboard.json b/website/public/locales/en/leaderboard.json index d1d7ed92..06097642 100644 --- a/website/public/locales/en/leaderboard.json +++ b/website/public/locales/en/leaderboard.json @@ -8,5 +8,7 @@ "score": "Score", "user": "User", "weekly": "Weekly", - "prompt_tasks": "Prompt Tasks" + "prompt": "Prompts", + "reply": "Replies", + "label": "Labels" } diff --git a/website/src/components/Dashboard/LeaderboardTable.tsx b/website/src/components/Dashboard/LeaderboardWidget.tsx similarity index 50% rename from website/src/components/Dashboard/LeaderboardTable.tsx rename to website/src/components/Dashboard/LeaderboardWidget.tsx index 52cf762b..aa89d585 100644 --- a/website/src/components/Dashboard/LeaderboardTable.tsx +++ b/website/src/components/Dashboard/LeaderboardWidget.tsx @@ -1,11 +1,9 @@ -import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react"; +import { Card, CardBody, Link, Text } from "@chakra-ui/react"; import NextLink from "next/link"; -import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; +import { LeaderboardTable } from "src/components/LeaderboardTable"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; -export function LeaderboardTable() { - const backgroundColor = useColorModeValue("white", "gray.700"); - const accentColor = useColorModeValue("gray.200", "gray.900"); +export function LeaderboardWidget() { return (
@@ -17,15 +15,11 @@ export function LeaderboardTable() {
- - - + + + + +
); diff --git a/website/src/components/Dashboard/index.ts b/website/src/components/Dashboard/index.ts index 84a93850..84858345 100644 --- a/website/src/components/Dashboard/index.ts +++ b/website/src/components/Dashboard/index.ts @@ -1,3 +1,3 @@ -export { LeaderboardTable } from "./LeaderboardTable"; +export { LeaderboardWidget } from "./LeaderboardWidget"; export { TaskOption } from "./TaskOption"; export { WelcomeCard } from "./WelcomeCard"; diff --git a/website/src/components/LeaderboardGridCell/index.tsx b/website/src/components/LeaderboardGridCell/index.tsx deleted file mode 100644 index c4657eb6..00000000 --- a/website/src/components/LeaderboardGridCell/index.tsx +++ /dev/null @@ -1 +0,0 @@ -export * from "./LeaderboardGridCell"; diff --git a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx b/website/src/components/LeaderboardTable/LeaderboardTable.tsx similarity index 74% rename from website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx rename to website/src/components/LeaderboardTable/LeaderboardTable.tsx index a65b7c2a..ca185be9 100644 --- a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx +++ b/website/src/components/LeaderboardTable/LeaderboardTable.tsx @@ -13,13 +13,13 @@ const columnHelper = createColumnHelper(); /** * Presents a grid of leaderboard entries with more detailed information. */ -const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) => { +export const LeaderboardTable = ({ timeFrame, limit }: { timeFrame: LeaderboardTimeFrame; limit: number }) => { const { t } = useTranslation("leaderboard"); const { data: reply, isLoading, error, - } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}`, get, { + } = useSWRImmutable(`/api/leaderboard?time_frame=${timeFrame}&limit=${limit}`, get, { revalidateOnMount: true, }); @@ -35,7 +35,13 @@ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) header: t("score"), }), columnHelper.accessor("prompts", { - header: t("pro"), + header: t("prompt"), + }), + columnHelper.accessor((row) => row.replies_assistant + row.replies_prompter, { + header: t("reply"), + }), + columnHelper.accessor((row) => row.labels_full + row.labels_simple, { + header: t("label"), }), ], [t] @@ -45,7 +51,6 @@ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) const val = new Date(reply?.last_updated); return t("last_updated_at", { val, formatParams: { val: { dateStyle: "full", timeStyle: "short" } } }); }, [t, reply?.last_updated]); - console.log(reply, isLoading); if (isLoading) { return ; @@ -55,7 +60,5 @@ const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) return Unable to load leaderboard; } - return ; + return ; }; - -export { LeaderboardGridCell }; diff --git a/website/src/components/LeaderboardTable/index.tsx b/website/src/components/LeaderboardTable/index.tsx new file mode 100644 index 00000000..d8de5feb --- /dev/null +++ b/website/src/components/LeaderboardTable/index.tsx @@ -0,0 +1 @@ +export * from "./LeaderboardTable"; diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index 5e5828ea..71d4fe44 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -1,4 +1,4 @@ -import { IconButton } from "@chakra-ui/react"; +import { Card, CardBody, IconButton } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; import { Pencil } from "lucide-react"; import Link from "next/link"; @@ -90,19 +90,21 @@ export const UserTable = memo(function UserTable() { }; return ( - <> - - {error && "Unable to load users."} - + + + + {error && "Unable to load users."} + + ); }); diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 47df8584..952858f2 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -262,8 +262,14 @@ export class OasstApiClient { /** * Returns the current leaderboard ranking. */ - async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise { - return this.get(`/api/v1/leaderboards/${time_frame}`); + async fetch_leaderboard( + time_frame: LeaderboardTimeFrame, + { limit = 20 }: { limit?: number } + ): Promise { + const params = new URLSearchParams({ + limit: limit.toString(), + }); + return this.get(`/api/v1/leaderboards/${time_frame}?${params.toString()}`); } /** diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index 1ddf947e..be91e7b4 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -7,7 +7,7 @@ import { LeaderboardTimeFrame } from "src/types/Leaderboard"; */ const handler = withoutRole("banned", async (req, res) => { const time_frame = (req.query.time_frame as LeaderboardTimeFrame) ?? LeaderboardTimeFrame.day; - const info = await oasstApiClient.fetch_leaderboard(time_frame); + const info = await oasstApiClient.fetch_leaderboard(time_frame, { limit: req.query.limit as unknown as number }); res.status(200).json(info); }); diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index 32774f4a..17e04c8a 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -2,7 +2,7 @@ import { Flex } from "@chakra-ui/react"; import Head from "next/head"; import { useTranslation } from "next-i18next"; import { useEffect, useMemo, useState } from "react"; -import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard"; +import { LeaderboardWidget, TaskOption, WelcomeCard } from "src/components/Dashboard"; import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; import { get } from "src/lib/api"; @@ -42,7 +42,7 @@ const Dashboard = () => { - + ); diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index 5afbc1b5..18f64bac 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -2,8 +2,8 @@ import { Box, Card, CardBody, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } import Head from "next/head"; import { useTranslation } from "next-i18next"; import { getDashboardLayout } from "src/components/Layout"; -import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; +import { LeaderboardTable } from "src/components/LeaderboardTable"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; const Leaderboard = () => { @@ -30,16 +30,16 @@ const Leaderboard = () => { - + - + - + - + From 0c890ae72631ab84a938dbf7a1e1b9ba097710bc Mon Sep 17 00:00:00 2001 From: notmd Date: Tue, 24 Jan 2023 02:09:33 +0700 Subject: [PATCH 046/111] fix `disablePagination` logic --- website/src/components/DataTable.tsx | 2 +- website/src/components/LeaderboardTable/LeaderboardTable.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx index d29a2813..a8106317 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable.tsx @@ -80,7 +80,7 @@ export const DataTable = ({ }; return ( <> - {disablePagination && ( + {!disablePagination && ( + + + ); +}; From 1eb3f05c44d1aea680b9cc936a459112159f49ae Mon Sep 17 00:00:00 2001 From: kayjay Date: Thu, 26 Jan 2023 01:49:03 -0800 Subject: [PATCH 073/111] PR: Create notebook to convert r/changemyview data (#839) * (#737) Create notebook to convert r/changemyview data into cleaner format --- .../changemyview-builder/README.md | 37 ++ .../changemyview-builder/data_processor.ipynb | 577 ++++++++++++++++++ 2 files changed, 614 insertions(+) create mode 100644 notebooks/data-augmentation/changemyview-builder/README.md create mode 100644 notebooks/data-augmentation/changemyview-builder/data_processor.ipynb diff --git a/notebooks/data-augmentation/changemyview-builder/README.md b/notebooks/data-augmentation/changemyview-builder/README.md new file mode 100644 index 00000000..62ab9b2e --- /dev/null +++ b/notebooks/data-augmentation/changemyview-builder/README.md @@ -0,0 +1,37 @@ +# README + +## Introduction + +This program converts data obtained from the subreddit r/changemyview into a cleaner format for further data processing. The data is not clean enough to be used directly in a model yet, and additional preprocessing is required. + +## Data Format + +The cleaned data is stored in an Apache Parquet file with the following columns: + +| Column Name | Description | Data Type | +|-------------|------------------------------------------------------------------------|----------------| +| INSTRUCTION | Post title + body text | String | +| RESPONSE | Body text of comments attempting to change OP's mind of `INSTRUCTION`. | List\ | +| SOURCE | Permalink to the reddit post | String | +| METADATA | Metadata related to `RESPONSE`. | Dict\ | + +### Metadata +Currently, metadata is only broken into one category: +- `detoxify_labels`- A Dictionary of values outputted by the [Unitaryai Detoxifier](https://github.com/unitaryai/detoxify) model, fitted to every comment under any given post. + +## Usage + +To use the program, follow these instructions: + +1. **Clone the repository** - `git clone https://github.com/LAION-AI/Open-Assistant.git` +2. **Navigate to the project directory** - `cd notebooks/data-augmentation/changemyview-builder` +3. **Open the Jupyter Notebook** - `jupyter notebook data_processor.ipynb` +4. **Run the program** - Go through the notebook and run the cells + +## Contributing + +If you would like to contribute to this project, please fork the repository and submit a pull request with your changes. + +## License + +This project is licensed under the Apache-2.0 License - see the [LICENSE](LICENSE) file for details. \ No newline at end of file diff --git a/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb b/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb new file mode 100644 index 00000000..03155d40 --- /dev/null +++ b/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb @@ -0,0 +1,577 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# r/ChangeMyView data converter\n", + "Converts subreddit data into readable format for ML training\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 65, + "outputs": [], + "source": [ + "### REMEMBER: setup the .env before running this code!\n", + "\n", + "\"\"\"CONSTANTS\"\"\"\n", + "\n", + "# Set the head number to the amount of entries you want to load in minus one\n", + "ENTRIES_COUNT = 10\n", + "\n", + "# Set the threshold for toxic comments to be removed\n", + "TOXIC_THRESHOLD = 0.95" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 66, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pandas in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (1.4.4)\r\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2.8.2)\r\n", + "Requirement already satisfied: pytz>=2020.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2022.1)\r\n", + "Requirement already satisfied: numpy>=1.18.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (1.21.5)\r\n", + "Requirement already satisfied: six>=1.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\r\n", + "Requirement already satisfied: praw in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (7.6.1)\r\n", + "Requirement already satisfied: websocket-client>=0.54.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.58.0)\r\n", + "Requirement already satisfied: update-checker>=0.18 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.18.0)\r\n", + "Requirement already satisfied: prawcore<3,>=2.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (2.3.0)\r\n", + "Requirement already satisfied: requests<3.0,>=2.6.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from prawcore<3,>=2.1->praw) (2.28.1)\r\n", + "Requirement already satisfied: six in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from websocket-client>=0.54.0->praw) (1.16.0)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (3.3)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2022.9.24)\r\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2.0.4)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (1.26.11)\r\n", + "Requirement already satisfied: python-dotenv in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.21.0)\r\n", + "Requirement already satisfied: pyarrow in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (10.0.1)\r\n", + "Requirement already satisfied: numpy>=1.16.6 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pyarrow) (1.21.5)\r\n", + "Requirement already satisfied: detoxify in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.5.1)\r\n", + "Requirement already satisfied: transformers==4.22.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (4.22.1)\r\n", + "Requirement already satisfied: torch>=1.7.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (1.13.1)\r\n", + "Requirement already satisfied: sentencepiece>=0.1.94 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (0.1.97)\r\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.9.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.11.1)\r\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2022.7.9)\r\n", + "Requirement already satisfied: pyyaml>=5.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (6.0)\r\n", + "Requirement already satisfied: tqdm>=4.27 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (4.64.1)\r\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.12.1)\r\n", + "Requirement already satisfied: filelock in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (3.6.0)\r\n", + "Requirement already satisfied: packaging>=20.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (21.3)\r\n", + "Requirement already satisfied: numpy>=1.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (1.21.5)\r\n", + "Requirement already satisfied: requests in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2.28.1)\r\n", + "Requirement already satisfied: typing-extensions in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from torch>=1.7.0->detoxify) (4.3.0)\r\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from packaging>=20.0->transformers==4.22.1->detoxify) (3.0.9)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (3.3)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (1.26.11)\r\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2.0.4)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2022.9.24)\r\n", + "Requirement already satisfied: tqdm in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (4.64.1)\r\n" + ] + } + ], + "source": [ + "# Install any dependencies\n", + "!pip install pandas\n", + "!pip install praw\n", + "!pip install python-dotenv\n", + "!pip install pyarrow\n", + "!pip install detoxify\n", + "!pip install tqdm" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import praw\n", + "import os\n", + "from os.path import join, dirname\n", + "from dotenv import main\n", + "\n", + "# Make sure you create a .env file and fill in all the necessary information in the same folder as this script!\n", + "main.load_dotenv(join(dirname(os.path.realpath('__file__')), '.env'))\n", + "\n", + "reddit = praw.Reddit(\n", + " client_id=os.environ.get(\"CLIENT_ID\"),\n", + " client_secret=os.environ.get(\"CLIENT_SECRET\"),\n", + " user_agent=\"CMV_Scraper\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "outputs": [], + "source": [ + "# load the data\n", + "import tarfile\n", + "import os.path\n", + "import json\n", + "import re\n", + "from bz2 import BZ2File\n", + "from urllib import request\n", + "from io import BytesIO\n", + "\n", + "import numpy as np\n", + "\n", + "\n", + "fname = \"cmv.tar.bz2\"\n", + "url = \"https://chenhaot.com/data/cmv/\" + fname\n", + "\n", + "# download if not exists\n", + "if not os.path.isfile(fname):\n", + " f = BytesIO()\n", + " with request.urlopen(url) as resp, open(fname, 'wb') as f_disk:\n", + " data = resp.read()\n", + " f_disk.write(data) # save to disk too\n", + " f.write(data)\n", + " f.seek(0)\n", + "else:\n", + " f = open(fname, 'rb')\n", + "\n", + "\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 69, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kayjaymac/opt/anaconda3/lib/python3.9/bz2.py:124: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n", + " self._buffer = None\n", + "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n" + ] + } + ], + "source": [ + "#tar = tarfile.open(fileobj=f, mode=\"r:bz2\")\n", + "tar = tarfile.open(fileobj=f, mode=\"r\")\n", + "\n", + "# Extract the file we are interested in\n", + "\n", + "train_fname = \"op_task/train_op_data.jsonlist.bz2\"\n", + "test_fname = \"op_task/heldout_op_data.jsonlist.bz2\"\n", + "\n", + "train_bzlist = tar.extractfile(train_fname)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 70, + "outputs": [], + "source": [ + "# Deserialize the JSON list\n", + "original_posts_train = [\n", + " json.loads(line.decode('utf-8'))\n", + " for line in BZ2File(train_bzlist)\n", + "]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 71, + "outputs": [ + { + "data": { + "text/plain": "[{'title': \"CMV: I shouldn't get a job in this economic climate because it'll be automated anyway; I should just wait for a post-scarcity utopia.\",\n 'delta_label': False,\n 'name': 't3_2rpsl8',\n 'selftext': \"I think the world is automating fast enough that a utopia will arise where no one will have to work anymore. Within the next 2 decades or so, having a job won't mean much, and most people will be artists and scientists. \\n\\nMy parents let me live with them, so I can just wait until the utopia happens.\\n\\nCMV.\"}]" + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "original_posts_train[:1]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 72, + "outputs": [], + "source": [ + "# Load the jsonlist file into a dataframe\n", + "#df = pd.read_json(original_posts_train, orient='list', lines=True)\n", + "df = pd.DataFrame(original_posts_train)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 73, + "outputs": [], + "source": [ + "# Function to check if the posts still exists on reddit\n", + "def try_get_post(post_id):\n", + " try:\n", + " submission = reddit.submission(id=post_id)\n", + " submission.name\n", + " return True\n", + " except Exception as e:\n", + " return False" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 74, + "outputs": [], + "source": [ + "# Set up the detoxifier model:\n", + "from detoxify import Detoxify" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "\n", + "# Removes > sign and the template message at the end of a message\n", + "def cleanup_body_text(cmv_post):\n", + " lines = [line for line in cmv_post.splitlines()\n", + " if not line.lstrip().startswith(\">\")\n", + " and not line.lstrip().startswith(\"____\")\n", + " and not line.lstrip().startswith(\"So go forth and CMV, noble redditors!\")\n", + " and \"edit\" not in \" \".join(line.lower().split()[:2])\n", + " ]\n", + " return \"\\n\".join(lines)\n", + "\n", + "\n", + "\n", + "\n", + "# Create the function that will be handling all the data gathering\n", + "def get_top_comment_and_clean_data(post_id):\n", + " #print(post_id.lstrip(\"t3_\"))\n", + " last_author = \"\"\n", + " # Grab the post\n", + " submission = reddit.submission(id=post_id.lstrip(\"t3_\"))\n", + " #print(submission.title)\n", + "\n", + " # Grab the highest rated comment on root layer\n", + " submission.submission_type = 'best'\n", + " submission.comments.replace_more(limit=0)\n", + " replies = list(submission.comments)[0].replies.list()\n", + "\n", + " # Just some variables\n", + " pros = []\n", + "\n", + " # If the post author doesn't exist this submission was deleted (submission.deleted doesn't work)\n", + " if type(submission.author) == type(None):\n", + " last_author = \"[deleted]\"\n", + " else:\n", + " last_author = submission.author.name\n", + "\n", + " is_pro_argument = False\n", + "\n", + " for comment in replies:\n", + "\n", + " # If redditor object doesn't exist, the account is invalid/deleted\n", + " if type(comment.author) != type(None):\n", + " author = comment.author.name\n", + " else:\n", + " author = \"[deleted]\"\n", + "\n", + " # Assume that whenever the user changes, they are countering the previous person\n", + " if author != last_author:\n", + " is_pro_argument = !is_pro_argument\n", + "\n", + " if author == \"[deleted]\" or author==\"DeltaBot\":\n", + " #print(\"Skipping comment...\")\n", + " continue\n", + "\n", + " # Remove meta and duplicate comments\n", + " comment.body = \" \".join([line for line in comment.body.splitlines()\n", + " if not re.search(r\"(?i)(Change\\smy\\sview|CMV)\", line)\n", + " and line not in pros # Why doesn't this line work\n", + " ])\n", + "\n", + " # Sometimes for some reason duplicate entries exist\n", + " # Also remove automated message with \"Δ\" in it\n", + "\n", + " if comment.body in pros:\n", + " #print(\"Skipping duplicate entry\")\n", + " continue\n", + "\n", + " #print(\"\\t\\t>>\\t\",comment.body)\n", + "\n", + " # Remove toxic comments\n", + " if Detoxify(\"multilingual\").predict(comment.body)[\"toxicity\"] > TOXIC_THRESHOLD:\n", + " #print(\"Identified toxic comment, ignoring...\")\n", + " comment.body = \"\"\n", + "\n", + " # Add to the respective argument type \n", + " if is_pro_argument:\n", + " pros.append(comment.body)\n", + " \n", + " last_author = comment.author.name\n", + " \n", + " # Pros = arguments for the Title of this post\n", + " # Cons = arguments against the title of this post\n", + "\n", + " pros.append(comment.body)\n", + " return pros" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading in 10 posts\n" + ] + } + ], + "source": [ + "print(f\"Loading in {ENTRIES_COUNT} posts\")\n", + "dataset = df.head(ENTRIES_COUNT)\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 77, + "outputs": [], + "source": [ + "# the name column does some weird sh** because dataframes already have a name property, so migrate to a different column name\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "dataset[\"post_id\"] = dataset[\"name\"]\n", + "warnings.filterwarnings('default')" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading in data... This will take a while.\n" + ] + }, + { + "data": { + "text/plain": " 0%| | 0/10 [00:00:29: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + "/Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages/torch/serialization.py:997: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n", + " storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()\n", + "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 7min 49s, sys: 2min 29s, total: 10min 19s\n", + "Wall time: 8min 45s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "from tqdm.auto import tqdm\n", + "# Reset variables for if we run this multiple times\n", + "all_pros = []\n", + "all_names = []\n", + "all_titles = []\n", + "all_sources = []\n", + "\n", + "print(\"Loading in data... This will take a while.\")\n", + "\n", + "for i in tqdm(range(dataset.shape[0])):\n", + "\n", + " post = dataset.iloc[i]\n", + " modified_title = post.title.replace('CMV', \"Change my mind\")\n", + " #print(f\"\\n Loading entry {i+1}/{dataset.shape[0]}:\\n\\t\\\"{modified_title}\\\"\")\n", + "\n", + " if type(post) == type(None):\n", + " continue\n", + "\n", + " assert(post.post_id != i)\n", + "\n", + " pros = get_top_comment_and_clean_data(post.post_id)\n", + "\n", + " if post.title == \"[deleted]\":\n", + " continue\n", + "\n", + " pros = \" \".join([*set(pros)])\n", + " pros = pros.replace(\"[deleted]\",\"\")\n", + "\n", + " post.selftext = cleanup_body_text(post.selftext)\n", + " all_titles.append(modified_title + \" \" + post.selftext)\n", + " all_pros.append(pros)\n", + " all_names.append(post.name)\n", + " all_sources.append(f\"https://reddit.com/r/changemyview/comments/{post.post_id}\")\n", + " #print(post.title)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "outputs": [ + { + "data": { + "text/plain": "'it\\'s already been signed. They even claim to be adhering to it, though they\\'ve been found to be violating it before. There is no such thing as \"de facto acceptance of Israel\\'s nuclear program.\" the Non-Proliferation Treaty is only binding for signatory states. Israel is not a signatory. Article 10 of the NPT allows them to withdraw if they so choose. they have not done so. a whole new country which explicitly has a right to withdraw from the NPT and has not chosen to do so. It\\'s more accurate, I think, to say that the problem with Iran here from a legal standpoint is that they aren\\'t honoring their own commitments, rather than that they\\'re building weapons. They could pull out of the NPT at any time, and the ball would be essentially in America\\'s court, because their nuclear program would no longer be illegal by international legal standards. However, Iran insists both on developing nukes *and* remaining an NPT signatory non-nuclear state, and that\\'s what makes their program illegal. I\\'d also like to clarify that I\\'m not making an ethical argument here, this is just how international law currently works. because international law doesn\\'t require states to sign treaties, it only requires them to adhere to treaties they\\'ve already signed. Israel isn\\'t defying the UN, at least not in this particular case. Think of the NPT less like a standard law within a state and more like a contract. Once you\\'ve signed, you\\'re bound by the contract, but if you never sign it then you haven\\'t broken a law, you\\'ve just decided not to agree to the terms you were offered. > Because Iran did sign the treaty, and thus are bound by it. They signed on July 1, 1968. Hmm. So is the argument here that it\\'s not \"ok\" for Iran to have a nuke, since they signed treaty not to do so. But it\\'s \"ok\" for Israel to have one because they never signed such thing? Can\\'t quite put my finger on it, but doesn\\'t seem quite right this one.'" + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_pros[1]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 80, + "outputs": [], + "source": [ + "# Place it all into a Pandas Dataframe\n", + "clean_df = pd.DataFrame({\n", + " \"INSTRUCTION\": all_titles,\n", + " \"RESPONSE\": all_pros,\n", + " \"SOURCE\": all_sources\n", + "}, index=all_names\n", + ")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Apache Paquete file\n", + "\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "\n", + "table = pa.Table.from_pandas(clean_df)\n", + "pq.write_table(table,\"output.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "outputs": [ + { + "data": { + "text/plain": " INSTRUCTION \\\n0 Change my mind: I shouldn't get a job in this ... \n1 Change my mind: Iran has the right to develop ... \n2 Change my mind: The events in Paris suck...but... \n3 Change my mind: It is ok to hate a religion so... \n4 Change my mind: There is no productive reason ... \n5 Change my mind: Diet soda is perfectly healthy... \n6 Change my mind:Essential Oils are bullshit My ... \n7 Change my mind: I think the Paris shooting mak... \n8 Change my mind: Printing an image of the Musli... \n9 Change my mind: Philosophy has no tangible val... \n\n RESPONSE \\\n0 That is what someone in the 1500s would have s... \n1 it's already been signed. They even claim to b... \n2 Hm I guess I made the OP incorrectly. The mai... \n3 I don't understand your analogy. Promoting a ... \n4 ∆ I hadn't thought it from a \"let's trick peop... \n5 Thanks for a fresh argument! I hadn't conside... \n6 Most do. Some smell kinda funky. \n7 I already said in different comments that thi... \n8 The first bacon sandwich came about because 9... \n9 >Why restrict it to 50 years? I can name all s... \n\n SOURCE \n0 https://reddit.com/r/changemyview/comments/t3_... \n1 https://reddit.com/r/changemyview/comments/t3_... \n2 https://reddit.com/r/changemyview/comments/t3_... \n3 https://reddit.com/r/changemyview/comments/t3_... \n4 https://reddit.com/r/changemyview/comments/t3_... \n5 https://reddit.com/r/changemyview/comments/t3_... \n6 https://reddit.com/r/changemyview/comments/t3_... \n7 https://reddit.com/r/changemyview/comments/t3_... \n8 https://reddit.com/r/changemyview/comments/t3_... \n9 https://reddit.com/r/changemyview/comments/t3_... ", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
INSTRUCTIONRESPONSESOURCE
0Change my mind: I shouldn't get a job in this ...That is what someone in the 1500s would have s...https://reddit.com/r/changemyview/comments/t3_...
1Change my mind: Iran has the right to develop ...it's already been signed. They even claim to b...https://reddit.com/r/changemyview/comments/t3_...
2Change my mind: The events in Paris suck...but...Hm I guess I made the OP incorrectly. The mai...https://reddit.com/r/changemyview/comments/t3_...
3Change my mind: It is ok to hate a religion so...I don't understand your analogy. Promoting a ...https://reddit.com/r/changemyview/comments/t3_...
4Change my mind: There is no productive reason ...∆ I hadn't thought it from a \"let's trick peop...https://reddit.com/r/changemyview/comments/t3_...
5Change my mind: Diet soda is perfectly healthy...Thanks for a fresh argument! I hadn't conside...https://reddit.com/r/changemyview/comments/t3_...
6Change my mind:Essential Oils are bullshit My ...Most do. Some smell kinda funky.https://reddit.com/r/changemyview/comments/t3_...
7Change my mind: I think the Paris shooting mak...I already said in different comments that thi...https://reddit.com/r/changemyview/comments/t3_...
8Change my mind: Printing an image of the Musli...The first bacon sandwich came about because 9...https://reddit.com/r/changemyview/comments/t3_...
9Change my mind: Philosophy has no tangible val...>Why restrict it to 50 years? I can name all s...https://reddit.com/r/changemyview/comments/t3_...
\n
" + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test to see if it was sucessful\n", + "table = pq.read_table(\"output.parquet\")\n", + "table.to_pandas()" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} From 5d4f74f9d60bb27c3d29d3c50f3b855869d59861 Mon Sep 17 00:00:00 2001 From: MattAlexMiracle Date: Thu, 26 Jan 2023 10:50:25 +0100 Subject: [PATCH 074/111] Ranked pairs (#933) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * commented out legacy numerical solver * added comments and task_scheduling for selecting which task to serve to users * removed standalone task weighting * pre-commit hook rerun * fixed ranking * fix index error * ranking fix * fix typo Co-authored-by: Alexander Mattick Co-authored-by: Andreas Köpf --- backend/oasst_backend/utils/ranking.py | 33 ++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/backend/oasst_backend/utils/ranking.py b/backend/oasst_backend/utils/ranking.py index 5538d7a3..0bb94fe8 100644 --- a/backend/oasst_backend/utils/ranking.py +++ b/backend/oasst_backend/utils/ranking.py @@ -96,13 +96,15 @@ def ranked_pairs(ranks: List[List[int]]): """ tallies, names = head_to_head_votes(ranks) tallies = tallies - tallies.T - # print(tallies) # note: the resulting tally matrix should be skew-symmetric # order by strength of victory (using tideman's original method, don't think it would make a difference for us) sorted_majorities = [] for i in range(len(ranks[0])): for j in range(len(ranks[0])): - if tallies[i, j] > 0: + # you can never prefer yourself over yourself + # we also have to pick one of the two choices, + # if the preference is exactly zero... + if tallies[i, j] >= 0 and i != j: sorted_majorities.append((i, j, tallies[i, j])) # we don't explicitly deal with tied majorities here sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True)) @@ -128,13 +130,36 @@ def ranked_pairs(ranks: List[List[int]]): if __name__ == "__main__": - ranks = ( + + ranks = """ ( [("w", "x", "z", "y") for _ in range(1)] + [("w", "y", "x", "z") for _ in range(2)] # + [("x","y","z","w") for _ in range(4)] + [("x", "z", "w", "y") for _ in range(5)] + [("y", "w", "x", "z") for _ in range(1)] # [("y","z","w","x") for _ in range(1000)] - ) + )""" + ranks = [ + [ + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ], + [ + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ], + [ + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ], + [ + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ], + ] rp = ranked_pairs(ranks) print(rp) From c2fa476904552ceca5675568f7645cae22de26fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 15:29:54 +0100 Subject: [PATCH 075/111] Add user emoji augmentation for message queries (#937) * add disposition to text labeling tasks * add emoji stats to ConversationMessage * add user emoji augmentation for message queries * add auth_method,username to message queries (query emoji status) * add auth_method+username for single message * fix param name typo * only join rows when message.emojis != JSON.NULL * formatting * make sure emojis and user_emojis default to {}, [] * remove init_user(), use fresh empty default collections --- .../oasst_backend/api/v1/frontend_users.py | 2 +- backend/oasst_backend/api/v1/messages.py | 60 ++++++++++++----- backend/oasst_backend/api/v1/users.py | 2 +- backend/oasst_backend/api/v1/utils.py | 5 +- backend/oasst_backend/models/message.py | 17 ++++- backend/oasst_backend/prompt_repository.py | 67 ++++++++++++++++--- backend/oasst_backend/tree_manager.py | 8 +++ oasst-shared/oasst_shared/schemas/protocol.py | 28 +++++--- 8 files changed, 149 insertions(+), 40 deletions(-) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 5ea7b26c..86f78026 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -77,7 +77,7 @@ def query_frontend_user_messages( """ Query frontend user messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index af3ae42d..b3aace40 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -34,7 +34,7 @@ def query_messages( """ Query messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, @@ -93,7 +93,7 @@ def get_messages_cursor( qry_max_count = max_count + 1 if before is None or after is None else max_count - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username, user_id=user_id) items = pr.query_messages_ordered_by_created_date( user_id=user_id, auth_method=auth_method, @@ -137,37 +137,49 @@ def get_messages_cursor( @router.get("/{message_id}", response_model=protocol.Message) def get_message( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a message by its internal ID. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) return utils.prepare_message(message) @router.get("/{message_id}/conversation", response_model=protocol.Conversation) def get_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a conversation from the tree root and up to the message with given internal ID. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.fetch_message_conversation(message_id) return utils.prepare_conversation(messages) @router.get("/{message_id}/tree", response_model=protocol.MessageTree) def get_tree( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False) return utils.prepare_tree(tree, message.message_tree_id) @@ -175,24 +187,32 @@ def get_tree( @router.get("/{message_id}/children", response_model=list[protocol.Message]) def get_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.fetch_message_children(message_id) return utils.prepare_message_list(messages) @router.get("/{message_id}/descendants", response_model=protocol.MessageTree) def get_descendants( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a subtree which starts with this message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -200,12 +220,16 @@ def get_descendants( @router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation) def get_longest_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get the longest conversation from the tree of the message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -213,12 +237,16 @@ def get_longest_conv( @router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree) def get_max_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get message with the most children from the tree of the provided message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index c0055339..d7497610 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -230,7 +230,7 @@ def query_user_messages( """ Query user messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, user_id=user_id) messages = pr.query_messages_ordered_by_created_date( user_id=user_id, api_client_id=api_client_id, diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 8b0f378f..5c9537a3 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -14,7 +14,8 @@ def prepare_message(m: Message) -> protocol.Message: lang=m.lang, is_assistant=(m.role == "assistant"), created_date=m.created_date, - emojis=m.emojis, + emojis=m.emojis or {}, + user_emojis=m.user_emojis or [], ) @@ -30,6 +31,8 @@ def prepare_conversation_message_list(messages: list[Message]) -> list[protocol. text=message.text, lang=message.lang, is_assistant=(message.role == "assistant"), + emojis=message.emojis or {}, + user_emojis=message.user_emojis or [], ) for message in messages ] diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index da0c06c3..5f323d5d 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,12 +1,13 @@ from datetime import datetime from http import HTTPStatus -from typing import Optional +from typing import Any, Optional from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from pydantic import PrivateAttr from sqlalchemy import false from sqlmodel import Field, Index, SQLModel @@ -17,6 +18,13 @@ class Message(SQLModel, table=True): __tablename__ = "message" __table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),) + def __new__(cls, *args: Any, **kwargs: Any): + new_object = super().__new__(cls, *args, **kwargs) + # temporary fix until https://github.com/tiangolo/sqlmodel/issues/149 gets merged + if not hasattr(new_object, "_user_emojis"): + new_object._init_private_attributes() + return new_object + id: Optional[UUID] = Field( sa_column=sa.Column( pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") @@ -49,7 +57,8 @@ class Message(SQLModel, table=True): rank: Optional[int] = Field(nullable=True) - emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False) + emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False) + _user_emojis: Optional[list[str]] = PrivateAttr(default=None) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): @@ -59,3 +68,7 @@ class Message(SQLModel, table=True): def text(self) -> str: self.ensure_is_message() return self.payload.payload.text + + @property + def user_emojis(self) -> str: + return self._user_emojis diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7dddb5cf..b31b53d7 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -30,8 +30,9 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats from oasst_shared.utils import unaware_to_utc +from sqlalchemy.orm import Query from sqlalchemy.orm.attributes import flag_modified -from sqlmodel import Session, and_, func, not_, or_, text, update +from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -41,14 +42,25 @@ class PromptRepository: db: Session, api_client: ApiClient, client_user: Optional[protocol_schema.User] = None, + *, user_repository: Optional[UserRepository] = None, task_repository: Optional[TaskRepository] = None, + user_id: Optional[UUID] = None, + auth_method: Optional[str] = None, + username: Optional[str] = None, ): self.db = db self.api_client = api_client self.user_repository = user_repository or UserRepository(db, api_client) - self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) - self.user_id = self.user.id if self.user else None + if user_id: + self.user = self.user_repository.get_user(id=user_id) + self.user_id = self.user.id + elif auth_method and username: + self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username) + self.user_id = self.user.id + else: + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) + self.user_id = self.user.id if self.user else None logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})") self.task_repository = task_repository or TaskRepository( db, api_client, client_user, user_repository=self.user_repository @@ -529,7 +541,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if not include_deleted: qry = qry.filter(not_(Message.deleted)) - return qry.all() + return self._add_user_emojis_all(qry) def fetch_user_message_trees( self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False @@ -539,7 +551,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if not include_deleted: qry = qry.filter(not_(Message.deleted)) - return qry.all() + return self._add_user_emojis_all(qry) def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]: qry = self.db.query(MessageTreeState).filter( @@ -582,6 +594,10 @@ class PromptRepository: return conversation, replies def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]: + qry = self.db.query(Message).filter(Message.id == message_id) + messages = self._add_user_emojis_all(qry) + message = messages[0] if messages else None + message = self.db.query(Message).filter(Message.id == message_id).one_or_none() if fail_if_missing and not message: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) @@ -656,7 +672,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if exclude_deleted: qry = qry.filter(Message.deleted == sa.false()) - children = qry.all() + children = self._add_user_emojis_all(qry) return children def fetch_message_siblings( @@ -674,7 +690,7 @@ class PromptRepository: qry = qry.filter(Message.review_result == reviewed) if deleted is not None: qry = qry.filter(Message.deleted == deleted) - siblings = qry.all() + siblings = self._add_user_emojis_all(qry) return siblings @staticmethod @@ -705,7 +721,7 @@ class PromptRepository: if max_depth is not None: desc = desc.filter(Message.depth <= max_depth) - desc = desc.all() + desc = self._add_user_emojis_all(desc) return self.trace_descendants(message, desc) @@ -719,6 +735,33 @@ class PromptRepository: max_message = max(tree, key=lambda m: m.children_count) return max_message, [m for m in tree if m.parent_id == max_message.id] + def _add_user_emojis_all(self, qry: Query) -> list[Message]: + if self.user_id is None: + return qry.all() + + sq = qry.subquery("m") + qry = ( + self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis")) + .select_entity_from(sq) + .outerjoin( + MessageEmoji, + and_( + sq.c.id == MessageEmoji.message_id, + MessageEmoji.user_id == self.user_id, + sq.c.emojis != JSON.NULL, + ), + ) + .group_by(sq) + ) + messages: list[Message] = [] + for x in qry: + m: Message = x.Message + user_emojis = x["user_emojis"] + if user_emojis: + m._user_emojis = user_emojis.split(",") + messages.append(m) + return messages + def query_messages_ordered_by_created_date( self, user_id: Optional[UUID] = None, @@ -801,7 +844,7 @@ class PromptRepository: if lang is not None: qry = qry.filter(Message.lang == lang) - return qry.all() + return self._add_user_emojis_all(qry) def update_children_counts(self, message_tree_id: UUID): sql_update_children_count = """ @@ -902,9 +945,15 @@ WHERE message.id = cc.id; else: count = emoji_counts.get(emoji.value) or 0 emoji_counts[emoji.value] = count + 1 + if message._user_emojis is None: + message._user_emojis = [] + if emoji.value not in message._user_emojis: + message._user_emojis.append(emoji.value) elif op == protocol_schema.EmojiOp.remove: # remove emoji record and & decrement count message = self.fetch_message(message_id) + if message._user_emojis and emoji.value in message._user_emojis: + message._user_emojis.remove(emoji.value) self.db.delete(existing_emoji) emoji_counts = message.emojis count = emoji_counts.get(emoji.value) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 64e1883e..929a9297 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -354,6 +354,7 @@ class TreeManager: self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 label_mode = protocol_schema.LabelTaskMode.full + label_disposition = protocol_schema.LabelTaskDisposition.quality valid_labels = self._all_text_labels if message.role == "assistant": @@ -363,6 +364,8 @@ class TreeManager: ): valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam + logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") task = protocol_schema.LabelAssistantReplyTask( message_id=message.id, @@ -371,6 +374,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), mode=label_mode, + disposition=label_disposition, ) else: if ( @@ -387,6 +391,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), mode=label_mode, + disposition=label_disposition, ) parent_message_id = message.id @@ -424,11 +429,13 @@ class TreeManager: message = random.choice(prompts_need_review) label_mode = protocol_schema.LabelTaskMode.full + label_disposition = protocol_schema.LabelTaskDisposition.quality valid_labels = self._all_text_labels if random.random() > self.cfg.p_full_labeling_review_prompt: valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).") task = protocol_schema.LabelInitialPromptTask( @@ -437,6 +444,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), mode=label_mode, + disposition=label_disposition, ) parent_message_id = message.id diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index de431d75..31caa340 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -57,6 +57,8 @@ class ConversationMessage(BaseModel): text: str lang: Optional[str] # BCP 47 is_assistant: bool + emojis: Optional[dict[str, int]] = None + user_emojis: Optional[list[str]] = None class Conversation(BaseModel): @@ -80,7 +82,6 @@ class Conversation(BaseModel): class Message(ConversationMessage): parent_id: Optional[UUID] = None created_date: Optional[datetime] = None - emojis: Optional[dict] = None class MessagePage(PageResult): @@ -223,27 +224,34 @@ class LabelTaskMode(str, enum.Enum): full = "full" -class LabelInitialPromptTask(Task): - """A task to label an initial prompt.""" +class LabelTaskDisposition(str, enum.Enum): + """Reason why the task was issued.""" - type: Literal["label_initial_prompt"] = "label_initial_prompt" + quality = "quality" + spam = "spam" + + +class AbstractLabelTask(Task): message_id: UUID - prompt: str valid_labels: list[str] mandatory_labels: Optional[list[str]] mode: Optional[LabelTaskMode] + disposition: Optional[LabelTaskDisposition] -class LabelConversationReplyTask(Task): +class LabelInitialPromptTask(AbstractLabelTask): + """A task to label an initial prompt.""" + + type: Literal["label_initial_prompt"] = "label_initial_prompt" + prompt: str + + +class LabelConversationReplyTask(AbstractLabelTask): """A task to label a reply to a conversation.""" type: Literal["label_conversation_reply"] = "label_conversation_reply" conversation: Conversation # the conversation so far - message_id: UUID reply: str - valid_labels: list[str] - mandatory_labels: Optional[list[str]] - mode: Optional[LabelTaskMode] class LabelPrompterReplyTask(LabelConversationReplyTask): From d4688835d54a0da51c986cf2b43cfbe14fdfdb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 16:33:03 +0100 Subject: [PATCH 076/111] check condition for scoring on startup --- backend/oasst_backend/tree_manager.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 929a9297..a4cad0c5 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -542,9 +542,7 @@ class TreeManager: ) _, task = pr.store_ranking(interaction) - - ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id) - self.update_message_ranks(task.message_tree_id, rankings_by_message) + self.check_condition_for_scoring_state(task.message_tree_id) case protocol_schema.TextLabels: logger.info( @@ -659,7 +657,8 @@ class TreeManager: return False, None self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) - return True, rankings_by_message + self.update_message_ranks(message_tree_id, rankings_by_message) + return True def update_message_ranks( self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]] @@ -976,7 +975,7 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki return rankings_by_message @managed_tx_method(CommitMode.COMMIT) - def ensure_tree_states(self): + def ensure_tree_states(self) -> None: """Add message tree state rows for all root nodes (inital prompt messages).""" missing_tree_ids = self.query_misssing_tree_states() @@ -988,6 +987,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})") self._insert_default_state(id, state=state) + rankings = ( + self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all() + ) + if len(rankings) > 0: + logger.info(f"Checking state of {len(rankings)} message trees in ranking state.") + for r in rankings: + self.check_condition_for_scoring_state(r.message_tree_id) + def query_num_active_trees(self, lang: str) -> int: query = ( self.db.query(func.count(MessageTreeState.message_tree_id)) From f1edcc8a285dbc184a14ab50dccc39d258d45c92 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Thu, 26 Jan 2023 16:41:57 +0100 Subject: [PATCH 077/111] added streaming worker --- inference/README.md | 7 ++ inference/worker/__main__.py | 66 +++++++++++++------ inference/worker/requirements.txt | 4 +- .../oasst_shared/schemas/inference.py | 4 ++ 4 files changed, 58 insertions(+), 23 deletions(-) diff --git a/inference/README.md b/inference/README.md index 3dee94f9..bd0272ad 100644 --- a/inference/README.md +++ b/inference/README.md @@ -26,6 +26,13 @@ pip install -r requirements.txt python __main__.py ``` +For the worker, you'll also want to have the text-generation-inference server +running: + +```bash +docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference +``` + Run the client: ```bash diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index ad5e5cef..c8c1a4c9 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -1,13 +1,12 @@ -import re -import time +import json import rel -import torch +import requests +import sseclient import typer import websocket from loguru import logger from oasst_shared.schemas import inference, protocol -from transformers import pipeline app = typer.Typer() @@ -16,9 +15,8 @@ app = typer.Typer() def main( backend_url: str = "ws://localhost:8000", model_name: str = "distilgpt2", + inference_server_url: str = "http://localhost:8001", ): - pipe = pipeline("text-generation", model=model_name) - def on_open(ws: websocket.WebSocket): worker_config = inference.WorkerConfig(model_name=model_name) ws.send(worker_config.json()) @@ -37,23 +35,49 @@ def main( prompt = "\n".join(messages) + "\nAssistant:" - # TODO: replace this with incremental generation - torch.manual_seed(work_request.seed) - model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ - 0 - ]["generated_text"] - model_output = model_output.strip() + # TODO: use the seed + # torch.manual_seed(work_request.seed) + # model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ + # 0 + # ]["generated_text"] + # model_output = model_output.strip() - # fake streaming - split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] - pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] - for piece in pieces: - if not piece: - continue - if piece.strip() in ("User:", "Assistant:"): + # # fake streaming + # split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] + # pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] + # for piece in pieces: + # if not piece: + # continue + # if piece.strip() in ("User:", "Assistant:"): + # break + # ws.send(inference.WorkResponsePacket(token=piece).json()) + # time.sleep(0.1) + # ws.send(inference.WorkResponsePacket(is_end=True).json()) + + response = requests.post( + f"{inference_server_url}/generate_stream", + json={ + "inputs": prompt, + "parameters": { + "max_new_tokens": work_request.max_new_tokens, + "do_sample": work_request.do_sample, + "top_k": work_request.top_k, + "top_p": work_request.top_p, + "temperature": work_request.temperature, + }, + }, + stream=True, + headers={"Accept": "text/event-stream"}, + ) + response.raise_for_status() + + client = sseclient.SSEClient(response) + for event in client.events(): + data = json.loads(event.data) + if data["is_end"]: break - ws.send(inference.WorkResponsePacket(token=piece).json()) - time.sleep(0.1) + intermediate = data["event"] + ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json()) ws.send(inference.WorkResponsePacket(is_end=True).json()) def on_error(ws: websocket.WebSocket, error: Exception): diff --git a/inference/worker/requirements.txt b/inference/worker/requirements.txt index c248c652..82169379 100644 --- a/inference/worker/requirements.txt +++ b/inference/worker/requirements.txt @@ -1,6 +1,6 @@ loguru rel -torch -transformers +requests +sseclient-py typer websocket-client diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 0acb5014..b50cef9c 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -14,6 +14,10 @@ class WorkRequest(pydantic.BaseModel): model_name: str = "distilgpt2" max_new_tokens: int = 100 seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + do_sample: bool = True + top_k: int = 50 + top_p: float = 0.9 + temperature: float = 1.0 class WorkResponsePacket(pydantic.BaseModel): From 348999a93636bfcadb72256175a0a7115ff0cf63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 19:06:25 +0100 Subject: [PATCH 078/111] exclude trees in ranking state in acitve tree count --- backend/oasst_backend/tree_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a4cad0c5..992e75dd 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -202,7 +202,7 @@ class TreeManager: lang = "en" logger.warning("Task availability request without lang tag received, assuming lang='en'.") - num_active_trees = self.query_num_active_trees(lang=lang) + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) extendible_parents = self.query_extendible_parents(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) @@ -230,7 +230,7 @@ class TreeManager: lang = "en" logger.warning("Task request without lang tag received, assuming 'en'.") - num_active_trees = self.query_num_active_trees(lang=lang) + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) extendible_parents = self.query_extendible_parents(lang=lang) @@ -995,12 +995,15 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki for r in rankings: self.check_condition_for_scoring_state(r.message_tree_id) - def query_num_active_trees(self, lang: str) -> int: + def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int: + """Count all active trees (optionally exclude those in ranking state).""" query = ( self.db.query(func.count(MessageTreeState.message_tree_id)) .join(Message, MessageTreeState.message_tree_id == Message.id) .filter(MessageTreeState.active, Message.lang == lang) ) + if exclude_ranking: + query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING) return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: From 040344a41f8beef9c0a3ad4e82474326186b121f Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Thu, 26 Jan 2023 21:01:52 +0100 Subject: [PATCH 079/111] made inference server a bit more robust --- inference/server/main.py | 88 ++++++++++++++++++++++-------------- inference/worker/__main__.py | 8 +++- 2 files changed, 60 insertions(+), 36 deletions(-) diff --git a/inference/server/main.py b/inference/server/main.py index f3ec02b1..4cb5f659 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -5,6 +5,7 @@ import uuid import fastapi import pydantic import redis.asyncio as redis +import websockets.exceptions from fastapi.middleware.cors import CORSMiddleware from loguru import logger from oasst_shared.schemas import inference, protocol @@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum): pending = "pending" in_progress = "in_progress" complete = "complete" + aborted_by_worker = "aborted_by_worker" class DbChatEntry(pydantic.BaseModel): @@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque async def work(websocket: fastapi.WebSocket): await websocket.accept() worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) - while True: - # find a pending task that matches the worker's config - # could also be implemented using task queues - # but general compatibility matching is tricky - for chat in CHATS.values(): - if (request := chat.pending_message_request) is not None: - if chat.message_request_state == MessageRequestState.pending: - if request.compatible_with(worker_config): + try: + while True: + print(websocket.client_state) + if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: + logger.warning("Worker disconnected") + break + # find a pending task that matches the worker's config + # could also be implemented using task queues + # but general compatibility matching is tricky + for chat in CHATS.values(): + if (request := chat.pending_message_request) is not None: + if chat.message_request_state == MessageRequestState.pending: + if request.compatible_with(worker_config): + break + else: + logger.debug("No pending tasks") + await asyncio.sleep(1) + continue + + chat.message_request_state = MessageRequestState.in_progress + + work_request = inference.WorkRequest( + conversation=chat.conversation, + model_name=request.model_name, + max_new_tokens=request.max_new_tokens, + ) + + logger.info(f"Created {work_request}") + try: + await websocket.send_text(work_request.json()) + except websockets.exceptions.ConnectionClosedError: + logger.warning("Worker disconnected") + websocket.close() + chat.message_request_state = MessageRequestState.pending + break + + try: + while True: + # maybe unnecessary to parse and re-serialize + # could just pass the raw string and mark end via empty string + response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + await redisClient.rpush(chat.id, response_packet.json()) + if response_packet.is_end: break - else: - logger.debug("No pending tasks") - await asyncio.sleep(1) - continue + except fastapi.WebSocketException: + # TODO: handle this better + logger.exception(f"Websocket closed during handling of {chat.id}") + chat.message_request_state = MessageRequestState.aborted_by_worker + raise - chat.message_request_state = MessageRequestState.in_progress - - work_request = inference.WorkRequest( - conversation=chat.conversation, - model_name=request.model_name, - max_new_tokens=request.max_new_tokens, - ) - - logger.info(f"Created {work_request}") - try: - await websocket.send_text(work_request.json()) - while True: - # maybe unnecessary to parse and re-serialize - # could just pass the raw string and mark end via empty string - response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) - await redisClient.rpush(chat.id, response_packet.json()) - if response_packet.is_end: - break - except fastapi.WebSocketException: - # TODO: handle this better - logger.exception(f"Websocket closed during handling of {chat.id}") - - chat.message_request_state = MessageRequestState.complete + chat.message_request_state = MessageRequestState.complete + except fastapi.WebSocketException: + logger.exception("Websocket closed") diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index c8c1a4c9..e5c15fb4 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -33,7 +33,13 @@ def main( # construct prompt messages = [_prepare_message(message) for message in work_request.conversation.messages] - prompt = "\n".join(messages) + "\nAssistant:" + prefix = ( + "The following is a conversation between a user and an assistant. " + "The assistant is helpful, creative, clever, and very friendly.\n" + "Assistant: Hello! How can I help you today?\n" + ) + + prompt = prefix + "\n".join(messages) + "\nAssistant:" # TODO: use the seed # torch.manual_seed(work_request.seed) From f3ffde47ffc1cf2f49a8f835cf2c1a4a38ebcc9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 23:00:54 +0100 Subject: [PATCH 080/111] add preferred lonely_children extension (#942) * add preferred lonely_children extension * simplify sampling process, lower the probability to 25% * exclude parents for replies that were recently used * lonely children := count > 0 * consider only tasks not done for parent exclusion * increase lonely child sampling probability --- ...84fcd6900dc_add_task_created_date_index.py | 26 ++++++++++++++ backend/oasst_backend/config.py | 9 +++++ backend/oasst_backend/models/task.py | 4 ++- backend/oasst_backend/task_repository.py | 16 ++++++++- backend/oasst_backend/tree_manager.py | 34 ++++++++++++++++--- 5 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py diff --git a/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py b/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py new file mode 100644 index 00000000..29fd1aec --- /dev/null +++ b/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py @@ -0,0 +1,26 @@ +"""add task created date index + +Revision ID: c84fcd6900dc +Revises: 40ed93df0ed5 +Create Date: 2023-01-26 18:35:43.061589 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c84fcd6900dc" +down_revision = "40ed93df0ed5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f("ix_task_created_date"), "task", ["created_date"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_task_created_date"), table_name="task") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index d831eca2..9952c654 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -57,6 +57,15 @@ class TreeManagerConfiguration(BaseModel): rank_prompter_replies: bool = False + lonely_children_count: int = 3 + """Number of children below which parents are preferred during sampling for reply tasks.""" + + p_lonely_child_extension: float = 0.8 + """Probability to select a parent with less than lonely_children_count children.""" + + recent_tasks_span_sec: int = 3 * 60 # 3 min + """Time in seconds of recent tasks to consider for exclusion during task selection.""" + class Settings(BaseSettings): PROJECT_NAME: str = "open-assistant backend" diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py index a59f689e..7f91b157 100644 --- a/backend/oasst_backend/models/task.py +++ b/backend/oasst_backend/models/task.py @@ -20,7 +20,9 @@ class Task(SQLModel, table=True): ), ) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()), + sa_column=sa.Column( + sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp() + ), ) expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index eb100fe3..5c5dea21 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import Optional from uuid import UUID @@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session +from sqlmodel import Session, func, or_ from starlette.status import HTTP_404_NOT_FOUND @@ -219,3 +220,16 @@ class TaskRepository: def fetch_task_by_id(self, task_id: UUID) -> Task: task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none() return task + + def fetch_recent_reply_tasks( + self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100 + ) -> list[Task]: + qry = self.db.query(Task).filter( + func.age(Task.created_date) < max_age, + or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"), + ) + if done is not None: + qry = qry.filter(Task.done == done) + if limit: + qry = qry.limit(limit) + return qry.all() diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 992e75dd..89a51807 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -1,7 +1,7 @@ import json import random import sys -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple @@ -339,6 +339,7 @@ class TreeManager: message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_REPLY: + if task_role == TaskRole.PROMPTER: replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review)) elif task_role == TaskRole.ASSISTANT: @@ -398,19 +399,44 @@ class TreeManager: message_tree_id = message.message_tree_id case TaskType.REPLY: - # select a tree with missing replies + + recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks( + max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False + ) + recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks} + if task_role == TaskRole.PROMPTER: extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents)) elif task_role == TaskRole.ASSISTANT: extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents)) + # select a tree with missing replies if len(extendible_parents) > 0: - random_parent = random.choice(extendible_parents) + random_parent: ExtendibleParentRow = None + if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1: + # check if we have extendible parents with a small number of replies + + lonely_children_parents = [ + p + for p in extendible_parents + if 0 < p.active_children_count < self.cfg.lonely_children_count + and p.parent_id not in recent_reply_task_parents + ] + if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension: + random_parent = random.choice(lonely_children_parents) + + if random_parent is None: + # try to exclude parents for which tasks were recently handed out + fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents] + if len(fresh_parents) > 0: + random_parent = random.choice(fresh_parents) + else: + random_parent = random.choice(extendible_parents) # fetch random conversation to extend logger.debug(f"selected {random_parent=}") messages = self.pr.fetch_message_conversation(random_parent.parent_id) - assert all(m.review_result for m in messages) # ensure all messages have positive review + assert all(m.review_result for m in messages) # ensure all messages have positive reviews conversation = prepare_conversation(messages) # generate reply task depending on last message From da1c81d2c9ffaf748a94dddce1de8ea4a014b341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 00:54:29 +0100 Subject: [PATCH 081/111] Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947) * add LabelDescription list to labeling tasks * make +1 & -1 emoji exclusive (only one of both or none) * add red_flag emoji to message when reported * fix task's valid labels * fix typo --- backend/oasst_backend/api/v1/text_labels.py | 29 +++++++++- backend/oasst_backend/config.py | 51 ++++++++++++++-- backend/oasst_backend/models/message.py | 6 ++ backend/oasst_backend/prompt_repository.py | 35 ++++++++--- backend/oasst_backend/schemas/text_labels.py | 11 +--- backend/oasst_backend/tree_manager.py | 33 +++++++---- oasst-shared/oasst_shared/schemas/protocol.py | 58 ++++++++++++------- 7 files changed, 169 insertions(+), 54 deletions(-) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index dc6cc889..2025fd4c 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository -from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse +from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TextLabel from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -45,7 +46,29 @@ def label_text( def get_valid_lables() -> ValidLabelsResponse: return ValidLabelsResponse( valid_labels=[ - LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text) - for l in protocol_schema.TextLabel + LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) + for l in TextLabel + ] + ) + + +@router.get("/report_labels") +def get_report_lables() -> ValidLabelsResponse: + report_labels = [ + TextLabel.spam, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + TextLabel.moral_judgement, + TextLabel.political_content, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.quality, + ] + return ValidLabelsResponse( + valid_labels=[ + LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) + for l in report_labels ] ) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 9952c654..157845d7 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TextLabel from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator @@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel): num_required_rankings: int = 3 """Number of rankings in which the message participated.""" - mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + labels_initial_prompt: list[TextLabel] = [ + TextLabel.spam, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.creativity, + TextLabel.humor, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + labels_assistant_reply: list[TextLabel] = [ + TextLabel.spam, + TextLabel.fails_task, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.creativity, + TextLabel.humor, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + labels_prompter_reply: list[TextLabel] = [ + TextLabel.spam, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.humor, + TextLabel.creativity, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for initial prompts.""" - mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for assistant replies.""" - mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for prompter replies.""" rank_prompter_replies: bool = False diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 5f323d5d..24fafc01 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -64,6 +64,12 @@ class Message(SQLModel, table=True): if not self.payload or not isinstance(self.payload.payload, MessagePayload): raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR) + def has_emoji(self, emoji_code: str) -> bool: + return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0 + + def has_user_emoji(self, emoji_code: str) -> bool: + return self._user_emojis and emoji_code in self._user_emojis + @property def text(self) -> str: self.ensure_is_message() diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index b31b53d7..d3f655fc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -461,15 +461,22 @@ class PromptRepository: ) if message_id: - message = self.fetch_message(message_id) - if task: + if not task: + if text_labels.is_report is True: + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag + ) + + # update existing record for repeated updates (same user no task associated) + existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) + if existing_text_label is not None: + existing_text_label.labels = text_labels.labels + model = existing_text_label + + else: + message = self.fetch_message(message_id) message.review_count += 1 self.db.add(message) - # for the same User id with no task id associated with the message, then update existing record for repeated updates - existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) - if existing_text_label is not None: - existing_text_label.labels = text_labels.labels - model = existing_text_label self.db.add(model) return model, task, message @@ -936,6 +943,20 @@ WHERE message.id = cc.id; op = protocol_schema.EmojiOp.add if op == protocol_schema.EmojiOp.add: + # hard coded exclusivity of thumbs_up & thumbs_down + if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji( + protocol_schema.EmojiCode.thumbs_down.value + ): + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down + ) + elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji( + protocol_schema.EmojiCode.thumbs_up.value + ): + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up + ) + # insert emoji record & increment count message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji) self.db.add(message_emoji) diff --git a/backend/oasst_backend/schemas/text_labels.py b/backend/oasst_backend/schemas/text_labels.py index 9135c558..e846d8f4 100644 --- a/backend/oasst_backend/schemas/text_labels.py +++ b/backend/oasst_backend/schemas/text_labels.py @@ -1,13 +1,6 @@ -from typing import Optional - +from oasst_shared.schemas.protocol import LabelDescription from pydantic import BaseModel -class LabelOption(BaseModel): - name: str - display_text: str - help_text: Optional[str] - - class ValidLabelsResponse(BaseModel): - valid_labels: list[LabelOption] + valid_labels: list[LabelDescription] diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 89a51807..77419184 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel): class TreeManager: - _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) - def __init__( self, db: Session, @@ -216,6 +214,15 @@ class TreeManager: incomplete_rankings=incomplete_rankings, ) + @staticmethod + def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]: + return [ + protocol_schema.LabelDescription( + name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text + ) + for l in valid_labels + ] + def next_task( self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random, @@ -356,14 +363,14 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.full label_disposition = protocol_schema.LabelTaskDisposition.quality - valid_labels = self._all_text_labels if message.role == "assistant": + valid_labels = self.cfg.labels_assistant_reply if ( desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_assistant ): - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) + valid_labels = self.cfg.mandatory_labels_assistant_reply label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam @@ -372,27 +379,30 @@ class TreeManager: message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) else: + valid_labels = self.cfg.labels_prompter_reply if ( desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_prompter ): - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) + valid_labels = self.cfg.mandatory_labels_prompter_reply label_mode = protocol_schema.LabelTaskMode.simple logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") task = protocol_schema.LabelPrompterReplyTask( message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) parent_message_id = message.id @@ -456,10 +466,10 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.full label_disposition = protocol_schema.LabelTaskDisposition.quality - valid_labels = self._all_text_labels + valid_labels = self.cfg.labels_initial_prompt if random.random() > self.cfg.p_full_labeling_review_prompt: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) + valid_labels = self.cfg.mandatory_labels_initial_prompt label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam @@ -467,10 +477,11 @@ class TreeManager: task = protocol_schema.LabelInitialPromptTask( message_id=message.id, prompt=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) parent_message_id = message.id @@ -577,7 +588,7 @@ class TreeManager: _, task, msg = pr.store_text_labels(interaction) - # if it was a respones for a task, check if we have enough reviews to calc review_result + # if it was a response for a task, check if we have enough reviews to calc review_result if task and msg: reviews = self.query_reviews_for_message(msg.id) acceptance_score = self._calculate_acceptance(reviews) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 31caa340..22a4adfb 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -231,12 +231,20 @@ class LabelTaskDisposition(str, enum.Enum): spam = "spam" +class LabelDescription(BaseModel): + name: str + widget: str + display_text: str + help_text: Optional[str] + + class AbstractLabelTask(Task): message_id: UUID valid_labels: list[str] mandatory_labels: Optional[list[str]] mode: Optional[LabelTaskMode] disposition: Optional[LabelTaskDisposition] + labels: Optional[list[LabelDescription]] class LabelInitialPromptTask(AbstractLabelTask): @@ -324,39 +332,48 @@ class MessageRanking(Interaction): ranking: conlist(item_type=int, min_items=1) +class LabelWidget(str, enum.Enum): + yes_no = "yes_no" + flag = "flag" + likert = "likert" + + class TextLabel(str, enum.Enum): """A label for a piece of text.""" - def __new__(cls, label: str, display_text: str = "", help_text: str = None): + def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None): obj = str.__new__(cls, label) obj._value_ = label + obj.widget = widget obj.display_text = display_text obj.help_text = help_text return obj - spam = "spam", "Seems to be intentionally low-quality or irrelevant" - fails_task = "fails_task", "Fails to follow the correct instruction / task" - not_appropriate = "not_appropriate", "Inappropriate for customer assistant" - violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm" - excessive_harm = ( - "excessive_harm", - "Content likely to cause excessive harm not justifiable in the context", - "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", - ) - sexual_content = "sexual_content", "Contains sexual content" - toxicity = "toxicity", "Contains rude, abusive, profane or insulting content" - moral_judgement = "moral_judgement", "Expresses moral judgement" - political_content = "political_content", "Expresses political views" - humor = "humor", "Contains humorous content including sarcasm" + # yes/no questions + spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant" + fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task" + + # flags + pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)" + not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate" hate_speech = ( "hate_speech", + LabelWidget.flag, "Content is abusive or threatening and expresses prejudice against a protected characteristic", - "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + "Prejudice refers to preconceived views not based on reason. Protected characteristics " + "include gender, ethnicity, religion, sexual orientation, and similar characteristics.", ) - threat = "threat", "Contains a threat against a person or persons" - misleading = "misleading", "Contains text which is incorrect or misleading" - helpful = "helpful", "Completes the task to a high standard" - creative = "creative", "Expresses creativity in responding to the task" + sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content" + moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement" + political_content = "political_content", LabelWidget.flag, "Expresses political views" + + # likert + quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message" + toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content" + humor = "humor", LabelWidget.likert, "Humorous content including sarcasm" + helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message" + creativity = "creativity", LabelWidget.likert, "Creativity" + violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm" class TextLabels(Interaction): @@ -367,6 +384,7 @@ class TextLabels(Interaction): labels: dict[TextLabel, float] message_id: UUID task_id: Optional[UUID] + is_report: Optional[bool] @property def has_message_id(self) -> bool: From 1011fa527c97e7b0f04aea905a588279177420fa Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Fri, 27 Jan 2023 13:25:07 +1100 Subject: [PATCH 082/111] website: Fix handling of new and unknown labels --- .../src/components/Survey/LabelInputGroup.tsx | 288 ++++++++++-------- 1 file changed, 167 insertions(+), 121 deletions(-) diff --git a/website/src/components/Survey/LabelInputGroup.tsx b/website/src/components/Survey/LabelInputGroup.tsx index 06e0888b..9a06b6ca 100644 --- a/website/src/components/Survey/LabelInputGroup.tsx +++ b/website/src/components/Survey/LabelInputGroup.tsx @@ -18,126 +18,172 @@ interface LabelInfo { inverted: boolean; } -// This should be moved to the valid labels api endpoint -const label_messages: { - [label: string]: LabelInfo; -} = { - spam: { - zeroText: "Not Spam", - zeroDescription: ["Suitable for training Open Assistant."], - oneText: "Spam", - oneDescription: [ - "Seems to be intentionally low-quality or irrelevant", - 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', - "This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", - "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", - ], - inverted: true, - }, - fails_task: { - zeroText: "Follows Instructions", - zeroDescription: [], - oneText: "Fails Task", - oneDescription: ["Fails to follow the correct instruction / task"], - inverted: true, - }, - not_appropriate: { - zeroText: "Appropriate", - zeroDescription: [], - oneText: "Inappropriate", - oneDescription: ["Inappropriate for customer assistant"], - inverted: true, - }, - violence: { - zeroText: "Harmless", - zeroDescription: [], - oneText: "Violent", - oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"], - inverted: true, - }, - excessive_harm: { - zeroText: "Safe", - zeroDescription: [], - oneText: "Harmful", - oneDescription: [ - "Content likely to cause excessive harm not justifiable in the context", - "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", - ], - inverted: true, - }, - sexual_content: { - zeroText: "Non Sexual", - zeroDescription: [], - oneText: "Sexual", - oneDescription: ["Contains sexual content"], - inverted: true, - }, - toxicity: { - zeroText: "Polite", - zeroDescription: [], - oneText: "Rude", - oneDescription: ["Contains rude, abusive, profane or insulting content"], - inverted: true, - }, - moral_judgement: { - zeroText: "Non-Judgemental", - zeroDescription: [], - oneText: "Judgemental", - oneDescription: ["Expresses moral judgement"], - inverted: true, - }, - political_content: { - zeroText: "Apolitical", - zeroDescription: [], - oneText: "Political", - oneDescription: ["Expresses political views"], - inverted: true, - }, - humor: { - zeroText: "Serious", - zeroDescription: [], - oneText: "Humorous", - oneDescription: ["Contains humorous content including sarcasm"], - inverted: false, - }, - hate_speech: { - zeroText: "Safe", - zeroDescription: [], - oneText: "Hateful", - oneDescription: [ - "Content is abusive or threatening and expresses prejudice against a protected characteristic", - "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", - ], - inverted: true, - }, - threat: { - zeroText: "Safe", - zeroDescription: [], - oneText: "Threatening", - oneDescription: ["Contains a threat against a person or persons"], - inverted: true, - }, - misleading: { - zeroText: "Accurate", - zeroDescription: [], - oneText: "Misleading", - oneDescription: ["Contains text which is incorrect or misleading"], - inverted: true, - }, - helpful: { - zeroText: "Unhelful", - zeroDescription: [], - oneText: "Helpful", - oneDescription: ["Completes the task to a high standard"], - inverted: false, - }, - creative: { - zeroText: "Boring", - zeroDescription: [], - oneText: "Creative", - oneDescription: ["Expresses creativity in responding to the task"], - inverted: false, - }, +const getLabelInfo = (label: string): LabelInfo => { + switch (label) { + case "spam": + return { + zeroText: "Not Spam", + zeroDescription: ["Suitable for training Open Assistant."], + oneText: "Spam", + oneDescription: [ + "Seems to be intentionally low-quality or irrelevant", + 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', + "This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", + "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", + ], + inverted: true, + }; + case "fails_task": + return { + zeroText: "Follows Instructions", + zeroDescription: [], + oneText: "Fails Task", + oneDescription: ["Fails to follow the correct instruction / task"], + inverted: true, + }; + case "not_appropriate": + return { + zeroText: "Appropriate", + zeroDescription: [], + oneText: "Inappropriate", + oneDescription: ["Inappropriate for customer assistant"], + inverted: true, + }; + case "violence": + return { + zeroText: "Harmless", + zeroDescription: [], + oneText: "Violent", + oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"], + inverted: true, + }; + case "excessive_harm": + return { + zeroText: "Safe", + zeroDescription: [], + oneText: "Harmful", + oneDescription: [ + "Content likely to cause excessive harm not justifiable in the context", + "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", + ], + inverted: true, + }; + case "sexual_content": + return { + zeroText: "Non Sexual", + zeroDescription: [], + oneText: "Sexual", + oneDescription: ["Contains sexual content"], + inverted: true, + }; + case "toxicity": + return { + zeroText: "Polite", + zeroDescription: [], + oneText: "Rude", + oneDescription: ["Contains rude, abusive, profane or insulting content"], + inverted: true, + }; + case "moral_judgement": + return { + zeroText: "Non-Judgemental", + zeroDescription: [], + oneText: "Judgemental", + oneDescription: ["Expresses moral judgement"], + inverted: true, + }; + case "political_content": + return { + zeroText: "Apolitical", + zeroDescription: [], + oneText: "Political", + oneDescription: ["Expresses political views"], + inverted: true, + }; + case "humor": + return { + zeroText: "Serious", + zeroDescription: [], + oneText: "Humorous", + oneDescription: ["Contains humorous content including sarcasm"], + inverted: false, + }; + case "hate_speech": + return { + zeroText: "Safe", + zeroDescription: [], + oneText: "Hateful", + oneDescription: [ + "Content is abusive or threatening and expresses prejudice against a protected characteristic", + "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + ], + inverted: true, + }; + case "threat": + return { + zeroText: "Safe", + zeroDescription: [], + oneText: "Threatening", + oneDescription: ["Contains a threat against a person or persons"], + inverted: true, + }; + case "misleading": + return { + zeroText: "Accurate", + zeroDescription: [], + oneText: "Misleading", + oneDescription: ["Contains text which is incorrect or misleading"], + inverted: true, + }; + case "helpful": + return { + zeroText: "Unhelful", + zeroDescription: [], + oneText: "Helpful", + oneDescription: ["Completes the task to a high standard"], + inverted: false, + }; + case "creative": + return { + zeroText: "Boring", + zeroDescription: [], + oneText: "Creative", + oneDescription: ["Expresses creativity in responding to the task"], + inverted: false, + }; + case "pii": + return { + zeroText: "Clean", + zeroDescription: [], + oneText: "Contains PII", + oneDescription: ["Contains personally identifing information"], + inverted: false, + }; + case "quality": + return { + zeroText: "Low Quality", + zeroDescription: [], + oneText: "High Quality", + oneDescription: [], + inverted: false, + }; + case "creativity": + return { + zeroText: "Ordinary", + zeroDescription: [], + oneText: "Creative", + oneDescription: [], + inverted: false, + }; + default: + return { + zeroText: `!${label}`, + zeroDescription: [], + oneText: label, + oneDescription: [], + inverted: false, + }; + } }; export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => { @@ -148,7 +194,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label return ( {labelIDs.map((labelId, idx) => { - const { zeroText, oneText, zeroDescription, oneDescription, inverted } = label_messages[labelId]; + const { zeroText, oneText, zeroDescription, oneDescription, inverted } = getLabelInfo(labelId); let textA = zeroText; let textB = oneText; From 24f4c0879626c45b9c94c034eca9b06be829c8c0 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Thu, 26 Jan 2023 13:13:46 +0000 Subject: [PATCH 083/111] Refactor task page routes and repetition Remove blank line Lint add blank line Link pre-commit --- .../e2e/tasks/no_tasks_available.cy.ts | 23 ++++++++++ website/public/locales/en/common.json | 5 ++- website/src/components/EmptyState.tsx | 2 +- website/src/components/TaskPage/TaskPage.tsx | 40 ++++++++++++++++++ website/src/lib/api.ts | 3 +- website/src/lib/constants.ts | 42 +++++++++++++++++++ website/src/pages/create/assistant_reply.tsx | 29 ++----------- website/src/pages/create/initial_prompt.tsx | 29 ++----------- website/src/pages/create/user_reply.tsx | 33 +++------------ .../pages/evaluate/rank_assistant_replies.tsx | 29 ++----------- .../pages/evaluate/rank_initial_prompts.tsx | 29 ++----------- .../src/pages/evaluate/rank_user_replies.tsx | 33 +++------------ .../src/pages/label/label_assistant_reply.tsx | 29 ++----------- .../src/pages/label/label_initial_prompt.tsx | 29 ++----------- .../src/pages/label/label_prompter_reply.tsx | 29 ++----------- website/src/pages/tasks/random.tsx | 32 ++------------ website/src/types/Task.ts | 2 +- 17 files changed, 147 insertions(+), 271 deletions(-) create mode 100644 website/cypress/e2e/tasks/no_tasks_available.cy.ts create mode 100644 website/src/components/TaskPage/TaskPage.tsx create mode 100644 website/src/lib/constants.ts diff --git a/website/cypress/e2e/tasks/no_tasks_available.cy.ts b/website/cypress/e2e/tasks/no_tasks_available.cy.ts new file mode 100644 index 00000000..27c33b48 --- /dev/null +++ b/website/cypress/e2e/tasks/no_tasks_available.cy.ts @@ -0,0 +1,23 @@ +describe("no tasks available", () => { + it("displays an empty state when no tasks are available", () => { + cy.signInWithEmail("cypress@example.com"); + cy.intercept( + { + method: "GET", + url: "/api/new_task/prompter_reply", + }, + { + statusCode: 500, + body: { + message: "No tasks of type 'label_prompter_reply' are currently available.", + errorCode: 1006, + httpStatusCode: 503, + }, + } + ).as("newTaskPrompterReply"); + cy.visit("/create/user_reply"); + cy.wait("@newTaskPrompterReply").then(() => { + cy.get('[data-cy="cy-no-tasks"]').should("exist"); + }); + }); +}); diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index 8f35eaab..0b0f9d37 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -9,11 +9,12 @@ "docs": "Docs", "github": "GitHub", "legal": "Legal", + "loading": "Loading...", + "more_information": "More Information", "privacy_policy": "Privacy Policy", "report_a_bug": "Report a Bug", "sign_in": "Sign In", "sign_out": "Sign Out", "terms_of_service": "Terms of Service", - "title": "Open Assistant", - "more_information": "More Information" + "title": "Open Assistant" } diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx index b0455774..63d8d3bf 100644 --- a/website/src/components/EmptyState.tsx +++ b/website/src/components/EmptyState.tsx @@ -15,7 +15,7 @@ export const EmptyState = (props: EmptyStateProps) => { - {props.text} + {props.text} Go back to the dashboard diff --git a/website/src/components/TaskPage/TaskPage.tsx b/website/src/components/TaskPage/TaskPage.tsx new file mode 100644 index 00000000..9fc26c42 --- /dev/null +++ b/website/src/components/TaskPage/TaskPage.tsx @@ -0,0 +1,40 @@ +import Head from "next/head"; +import { useTranslation } from "next-i18next"; +import { TaskEmptyState } from "src/components/EmptyState"; +import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Task } from "src/components/Tasks/Task"; +import { TaskInfos } from "src/components/Tasks/TaskTypes"; +import { apiHooksByType, ERROR_CODES } from "src/lib/constants"; +import { getTypeSafei18nKey } from "src/lib/i18n"; +import { TaskType } from "src/types/Task"; + +type TaskPageProps = { + type: TaskType; +}; + +export const TaskPage = ({ type }: TaskPageProps) => { + const { t } = useTranslation(["tasks", "common"]); + const apiHook = apiHooksByType[type]; + const { tasks, isLoading, reset, trigger, error } = apiHook(type); + const taskType = TaskInfos.find((taskType) => taskType.type === type); + + if (isLoading) { + return ; + } + + if (tasks.length === 0 || error?.errorCode === ERROR_CODES.TASK_REQUESTED_TYPE_NOT_AVAILABLE) { + return ; + } + + const task = tasks[0]; + + return ( + <> + + {t(getTypeSafei18nKey(`${taskType.id}.label`))} + + + + + ); +}; diff --git a/website/src/lib/api.ts b/website/src/lib/api.ts index d61016d2..2649daf8 100644 --- a/website/src/lib/api.ts +++ b/website/src/lib/api.ts @@ -17,7 +17,8 @@ export const post = (url: string, { arg: data }) => api.post(url, data).then((re api.interceptors.response.use( (response) => response, (error) => { - throw new OasstError(error.message ?? error, error.error_code, error?.response?.status || -1); + const err = error?.response?.data; + throw new OasstError(err?.message ?? error, err?.errorCode, error?.response?.httpStatusCode || -1); } ); diff --git a/website/src/lib/constants.ts b/website/src/lib/constants.ts new file mode 100644 index 00000000..6269dbf9 --- /dev/null +++ b/website/src/lib/constants.ts @@ -0,0 +1,42 @@ +import { + useCreateAssistantReply, + useCreateInitialPrompt, + useCreatePrompterReply, +} from "src/hooks/tasks/useCreateReply"; +import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI"; +import { + useLabelAssistantReplyTask, + useLabelInitialPromptTask, + useLabelPrompterReplyTask, +} from "src/hooks/tasks/useLabelingTask"; +import { + useRankAssistantRepliesTask, + useRankInitialPromptsTask, + useRankPrompterRepliesTask, +} from "src/hooks/tasks/useRankReplies"; +import { TaskType } from "src/types/Task"; + +export const ERROR_CODES = { + TASK_REQUESTED_TYPE_NOT_AVAILABLE: 1006, + TASK_INVALID_REQUEST_TYPE: 1000, + TASK_ACK_FAILED: 1001, + TASK_NACK_FAILED: 1002, + TASK_INVALID_RESPONSE_TYPE: 1003, + TASK_INTERACTION_REQUEST_FAILED: 1004, + TASK_GENERATION_FAILED: 1005, + TASK_AVAILABILITY_QUERY_FAILED: 1007, + TASK_MESSAGE_TOO_LONG: 1008, +}; + +export const apiHooksByType = { + [TaskType.random]: useGenericTaskAPI, + [TaskType.assistant_reply]: useCreateAssistantReply, + [TaskType.initial_prompt]: useCreateInitialPrompt, + [TaskType.label_assistant_reply]: useLabelAssistantReplyTask, + [TaskType.label_initial_prompt]: useLabelInitialPromptTask, + [TaskType.label_prompter_reply]: useLabelPrompterReplyTask, + [TaskType.prompter_reply]: useCreatePrompterReply, + [TaskType.rank_assistant_replies]: useRankAssistantRepliesTask, + [TaskType.rank_initial_prompts]: useRankInitialPromptsTask, + [TaskType.rank_prompter_replies]: useRankPrompterRepliesTask, +}; diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index 1c83eb23..0f3095a9 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const AssistantReply = () => { - const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Reply as Assistant - - - - - ); -}; +const AssistantReply = () => ; AssistantReply.getLayout = getDashboardLayout; diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 639df68f..c73f2e5d 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const InitialPrompt = () => { - const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Initial Prompt - - - - - ); -}; +const InitialPrompt = () => ; InitialPrompt.getLayout = getDashboardLayout; diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 5898439c..39218476 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -1,33 +1,10 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const UserReply = () => { - const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); +const PrompterReply = () => ; - if (isLoading) { - return ; - } +PrompterReply.getLayout = getDashboardLayout; - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Reply as User - - - - - ); -}; - -UserReply.getLayout = getDashboardLayout; - -export default UserReply; +export default PrompterReply; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index da79d92f..dd4c1df9 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const RankAssistantReplies = () => { - const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Rank Assistant Replies - - - - - ); -}; +const RankAssistantReplies = () => ; RankAssistantReplies.getLayout = getDashboardLayout; diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index f23fc0ed..1eb91289 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const RankInitialPrompts = () => { - const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Rank Initial Prompts - - - - - ); -}; +const RankInitialPrompts = () => ; RankInitialPrompts.getLayout = getDashboardLayout; diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index cee82b87..a1caba59 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -1,33 +1,10 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const RankUserReplies = () => { - const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); +const RankPrompterReplies = () => ; - if (isLoading) { - return ; - } +RankPrompterReplies.getLayout = getDashboardLayout; - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Rank User Replies - - - - - ); -}; - -RankUserReplies.getLayout = getDashboardLayout; - -export default RankUserReplies; +export default RankPrompterReplies; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 07a6cb1c..8be12b41 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const LabelAssistantReply = () => { - const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Label Assistant Reply - - - - - ); -}; +const LabelAssistantReply = () => ; LabelAssistantReply.getLayout = getDashboardLayout; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 8735044f..c5fed344 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const LabelInitialPrompt = () => { - const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Label Initial Prompt - - - - - ); -}; +const LabelInitialPrompt = () => ; LabelInitialPrompt.getLayout = getDashboardLayout; diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 17164e11..33e8aba4 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const LabelPrompterReply = () => { - const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Label Prompter Reply - - - - - ); -}; +const LabelPrompterReply = () => ; LabelPrompterReply.getLayout = getDashboardLayout; diff --git a/website/src/pages/tasks/random.tsx b/website/src/pages/tasks/random.tsx index f1c04d2c..cd7ed458 100644 --- a/website/src/pages/tasks/random.tsx +++ b/website/src/pages/tasks/random.tsx @@ -1,34 +1,10 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; import { TaskType } from "src/types/Task"; -const RandomTask = () => { - const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random); +const Random = () => ; - if (isLoading) { - return ; - } +Random.getLayout = getDashboardLayout; - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Random Task - - - - - ); -}; - -RandomTask.getLayout = (page) => getDashboardLayout(page); - -export default RandomTask; +export default Random; diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index 12e37db0..7ae48138 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -1,4 +1,4 @@ -export const enum TaskType { +export enum TaskType { initial_prompt = "initial_prompt", assistant_reply = "assistant_reply", prompter_reply = "prompter_reply", From e6009933db67de3d40cdbac87104858471dd7023 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Fri, 27 Jan 2023 10:16:25 +0000 Subject: [PATCH 084/111] Rename const to taskInfo --- website/src/components/TaskPage/TaskPage.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/src/components/TaskPage/TaskPage.tsx b/website/src/components/TaskPage/TaskPage.tsx index 9fc26c42..41d89405 100644 --- a/website/src/components/TaskPage/TaskPage.tsx +++ b/website/src/components/TaskPage/TaskPage.tsx @@ -16,7 +16,7 @@ export const TaskPage = ({ type }: TaskPageProps) => { const { t } = useTranslation(["tasks", "common"]); const apiHook = apiHooksByType[type]; const { tasks, isLoading, reset, trigger, error } = apiHook(type); - const taskType = TaskInfos.find((taskType) => taskType.type === type); + const taskInfo = TaskInfos.find((taskType) => taskType.type === type); if (isLoading) { return ; @@ -31,8 +31,8 @@ export const TaskPage = ({ type }: TaskPageProps) => { return ( <> - {t(getTypeSafei18nKey(`${taskType.id}.label`))} - + {t(getTypeSafei18nKey(`${taskInfo.id}.label`))} + From a0f4449e9f9ef1919859a60c30c459eb4a9ba968 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 27 Jan 2023 15:23:19 +0100 Subject: [PATCH 085/111] added seed to parameters --- inference/worker/__main__.py | 20 +------------------ .../oasst_shared/schemas/inference.py | 2 +- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index e5c15fb4..190ed788 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -41,25 +41,6 @@ def main( prompt = prefix + "\n".join(messages) + "\nAssistant:" - # TODO: use the seed - # torch.manual_seed(work_request.seed) - # model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ - # 0 - # ]["generated_text"] - # model_output = model_output.strip() - - # # fake streaming - # split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] - # pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] - # for piece in pieces: - # if not piece: - # continue - # if piece.strip() in ("User:", "Assistant:"): - # break - # ws.send(inference.WorkResponsePacket(token=piece).json()) - # time.sleep(0.1) - # ws.send(inference.WorkResponsePacket(is_end=True).json()) - response = requests.post( f"{inference_server_url}/generate_stream", json={ @@ -70,6 +51,7 @@ def main( "top_k": work_request.top_k, "top_p": work_request.top_p, "temperature": work_request.temperature, + "seed": work_request.seed, }, }, stream=True, diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index b50cef9c..91a16b61 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -13,7 +13,7 @@ class WorkRequest(pydantic.BaseModel): conversation: protocol.Conversation = pydantic.Field(..., repr=False) model_name: str = "distilgpt2" max_new_tokens: int = 100 - seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1)) do_sample: bool = True top_k: int = 50 top_p: float = 0.9 From ae5d16f3942d5f365dc231bd027295555e4fd392 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 27 Jan 2023 15:40:21 +0100 Subject: [PATCH 086/111] added a tmux inference dev setup script --- inference/full-dev-setup.sh | 19 +++++++++++++++++++ inference/worker/__main__.py | 2 ++ 2 files changed, 21 insertions(+) create mode 100755 inference/full-dev-setup.sh diff --git a/inference/full-dev-setup.sh b/inference/full-dev-setup.sh new file mode 100755 index 00000000..98a5b173 --- /dev/null +++ b/inference/full-dev-setup.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Creates a tmux window with splits for the individual services + +tmux new-session -d -s "inference-dev-setup" +tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m +tmux split-window -h +tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m +tmux split-window -h +tmux send-keys "cd server" C-m +tmux send-keys "uvicorn main:app --reload" C-m +tmux split-window -h +tmux send-keys "cd worker" C-m +tmux send-keys "python __main__.py" C-m +tmux split-window -h +tmux send-keys "cd text-client" C-m +tmux send-keys "sleep 5" C-m +tmux send-keys "python __main__.py" C-m +tmux attach-session -t "inference-dev-setup" diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index 190ed788..96fe164a 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -18,8 +18,10 @@ def main( inference_server_url: str = "http://localhost:8001", ): def on_open(ws: websocket.WebSocket): + logger.info("Connected to backend, sending config...") worker_config = inference.WorkerConfig(model_name=model_name) ws.send(worker_config.json()) + logger.info("Config sent, waiting for work...") def on_message(ws: websocket.WebSocket, message: str): # TODO: what if this comes in, but one is already in progress? From 45a4b09eae04d2a6d3f099da155d576d7efb20f9 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 27 Jan 2023 15:42:01 +0100 Subject: [PATCH 087/111] added setup instructions to readme --- inference/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/inference/README.md b/inference/README.md index bd0272ad..6e1da2c7 100644 --- a/inference/README.md +++ b/inference/README.md @@ -2,7 +2,13 @@ Preliminary implementation of the inference engine for OpenAssistant. -## Development (you'll need multiple terminals) +## Development Variant 1 (you'll need tmux) + +Run `./full-dev-setup.sh` to start the full development setup. Make sure to wait +until the 2nd terminal is ready and says `{"message":"Connected"}` before +entering input into the last terminal. + +## Development Variant 2 (you'll need multiple terminals) Run a redis container (or use the one of the general docker compose file): From 3b04080d7be9f07362a015b5e6b27ff463705bf7 Mon Sep 17 00:00:00 2001 From: James Melvin Ebenezer Date: Fri, 27 Jan 2023 22:36:25 +0530 Subject: [PATCH 088/111] 949_transaction error handling (#950) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: transaction error handling * refactor: retry handling for all decorators as per review comments * fix: raising retry exhausted error * fix: avoid auto refresh on RollBack and review comments * removed refresh_result param from managed_tx_function --------- Co-authored-by: James Melvin Co-authored-by: Andreas Köpf --- backend/oasst_backend/utils/database_utils.py | 168 ++++++++++++------ 1 file changed, 114 insertions(+), 54 deletions(-) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index fb8bf6c5..5f7a3136 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -7,9 +7,14 @@ from loguru import logger from oasst_backend.config import settings from oasst_backend.database import engine from oasst_shared.exceptions import OasstError, OasstErrorCode -from sqlalchemy.exc import OperationalError +from psycopg2.errors import DeadlockDetected, ExclusionViolation, SerializationFailure, UniqueViolation +from sqlalchemy.exc import OperationalError, PendingRollbackError from sqlmodel import Session, SQLModel +""" +Error Handling Reference: https://www.postgresql.org/docs/15/mvcc-serialization-failure-handling.html +""" + class CommitMode(IntEnum): """ @@ -34,28 +39,46 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s @wraps(f) def wrapped_f(self, *args, **kwargs): try: - for i in range(num_retries): - try: - result = f(self, *args, **kwargs) - if auto_commit == CommitMode.COMMIT: + result = None + if auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + try: + result = f(self, *args, **kwargs) self.db.commit() - elif auto_commit == CommitMode.FLUSH: - self.db.flush() - elif auto_commit == CommitMode.ROLLBACK: + if isinstance(result, SQLModel): + self.db.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) self.db.rollback() + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation) + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + self.db.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + result = f(self, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + self.db.flush() if isinstance(result, SQLModel): self.db.refresh(result) - return result - except OperationalError: - logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + elif auto_commit == CommitMode.ROLLBACK: self.db.rollback() - raise OasstError( - "DATABASE_MAX_RETIRES_EXHAUSTED", - error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, - http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, - ) + return result except Exception as e: - logger.error("DB Rollback Failure") + logger.info(str(e)) raise e return wrapped_f @@ -70,28 +93,46 @@ def async_managed_tx_method( @wraps(f) async def wrapped_f(self, *args, **kwargs): try: - for i in range(num_retries): - try: - result = await f(self, *args, **kwargs) - if auto_commit == CommitMode.COMMIT: + result = None + if auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + try: + result = f(self, *args, **kwargs) self.db.commit() - elif auto_commit == CommitMode.FLUSH: - self.db.flush() - elif auto_commit == CommitMode.ROLLBACK: + if isinstance(result, SQLModel): + self.db.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) self.db.rollback() + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation) + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + self.db.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + result = f(self, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + self.db.flush() if isinstance(result, SQLModel): self.db.refresh(result) - return result - except OperationalError: - logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + elif auto_commit == CommitMode.ROLLBACK: self.db.rollback() - raise OasstError( - "DATABASE_MAX_RETIRES_EXHAUSTED", - error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, - http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, - ) + return result except Exception as e: - logger.exception("DB Rollback Failure") + logger.info(str(e)) raise e return wrapped_f @@ -107,7 +148,6 @@ def managed_tx_function( auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT, session_factory: Callable[..., Session] = default_session_factor, - refresh_result: bool = True, ): """Passes Session object as first argument to wrapped function.""" @@ -115,29 +155,49 @@ def managed_tx_function( @wraps(f) def wrapped_f(*args, **kwargs): try: - for i in range(num_retries): - with session_factory() as session: - try: - result = f(session, *args, **kwargs) - if auto_commit == CommitMode.COMMIT: + result = None + if auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + with session_factory() as session: + try: + result = f(session, *args, **kwargs) session.commit() - elif auto_commit == CommitMode.FLUSH: - session.flush() - elif auto_commit == CommitMode.ROLLBACK: + if isinstance(result, SQLModel): + session.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) session.rollback() - if refresh_result and isinstance(result, SQLModel): - session.refresh(result) - return result - except OperationalError: - logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") - session.rollback() - raise OasstError( - "DATABASE_MAX_RETIRES_EXHAUSTED", - error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, - http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, - ) + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, + (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation), + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + session.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + with session_factory() as session: + result = f(session, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + session.flush() + if isinstance(result, SQLModel): + session.refresh(result) + elif auto_commit == CommitMode.ROLLBACK: + session.rollback() + return result except Exception as e: - logger.error("DB Rollback Failure") + logger.info(str(e)) raise e return wrapped_f From ce3b3c7eccd8aadf288b6c67c685c5ec805a97e9 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 28 Jan 2023 00:17:31 +0700 Subject: [PATCH 089/111] pass correct param when fetch leaderboard --- website/src/lib/oasst_api_client.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index bd263400..3cc4f4ef 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -144,7 +144,7 @@ export class OasstApiClient { time_frame: LeaderboardTimeFrame, { limit = 20 }: { limit?: number } ): Promise { - return this.get(`/api/v1/leaderboards/${time_frame}`, { limit }); + return this.get(`/api/v1/leaderboards/${time_frame}`, { max_count: limit }); } /** From d16598725664241c554c466ebf30a9b2900af4f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 19:43:08 +0100 Subject: [PATCH 090/111] add missing await to async_managed_tx_method --- backend/oasst_backend/utils/database_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 5f7a3136..196178e5 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -98,7 +98,7 @@ def async_managed_tx_method( retry_exhausted = True for i in range(num_retries): try: - result = f(self, *args, **kwargs) + result = await f(self, *args, **kwargs) self.db.commit() if isinstance(result, SQLModel): self.db.refresh(result) From c7692b9049a25ff00c6e39d2854a125cbbc5411d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 19:44:48 +0100 Subject: [PATCH 091/111] add 2nd missing await to async_managed_tx_method --- backend/oasst_backend/utils/database_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 196178e5..5ac25f50 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -123,7 +123,7 @@ def async_managed_tx_method( http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, ) else: - result = f(self, *args, **kwargs) + result = await f(self, *args, **kwargs) if auto_commit == CommitMode.FLUSH: self.db.flush() if isinstance(result, SQLModel): From bf77d4dc60d1a295c02242feec3e8c9ce1e08835 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Fri, 27 Jan 2023 10:51:32 +0000 Subject: [PATCH 092/111] Add types for TaskType to TaskHook Pre-commit Apply better naming to task api hooks Lint --- website/src/components/TaskPage/TaskPage.tsx | 6 ++--- website/src/lib/constants.ts | 3 ++- website/src/types/Hooks.ts | 26 ++++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 website/src/types/Hooks.ts diff --git a/website/src/components/TaskPage/TaskPage.tsx b/website/src/components/TaskPage/TaskPage.tsx index 41d89405..e1ecc79c 100644 --- a/website/src/components/TaskPage/TaskPage.tsx +++ b/website/src/components/TaskPage/TaskPage.tsx @@ -4,7 +4,7 @@ import { TaskEmptyState } from "src/components/EmptyState"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { TaskInfos } from "src/components/Tasks/TaskTypes"; -import { apiHooksByType, ERROR_CODES } from "src/lib/constants"; +import { ERROR_CODES, taskApiHooks } from "src/lib/constants"; import { getTypeSafei18nKey } from "src/lib/i18n"; import { TaskType } from "src/types/Task"; @@ -14,8 +14,8 @@ type TaskPageProps = { export const TaskPage = ({ type }: TaskPageProps) => { const { t } = useTranslation(["tasks", "common"]); - const apiHook = apiHooksByType[type]; - const { tasks, isLoading, reset, trigger, error } = apiHook(type); + const taskApiHook = taskApiHooks[type]; + const { tasks, isLoading, reset, trigger, error } = taskApiHook(type); const taskInfo = TaskInfos.find((taskType) => taskType.type === type); if (isLoading) { diff --git a/website/src/lib/constants.ts b/website/src/lib/constants.ts index 6269dbf9..a260fa28 100644 --- a/website/src/lib/constants.ts +++ b/website/src/lib/constants.ts @@ -14,6 +14,7 @@ import { useRankInitialPromptsTask, useRankPrompterRepliesTask, } from "src/hooks/tasks/useRankReplies"; +import { TaskApiHooks } from "src/types/Hooks"; import { TaskType } from "src/types/Task"; export const ERROR_CODES = { @@ -28,7 +29,7 @@ export const ERROR_CODES = { TASK_MESSAGE_TOO_LONG: 1008, }; -export const apiHooksByType = { +export const taskApiHooks: TaskApiHooks = { [TaskType.random]: useGenericTaskAPI, [TaskType.assistant_reply]: useCreateAssistantReply, [TaskType.initial_prompt]: useCreateInitialPrompt, diff --git a/website/src/types/Hooks.ts b/website/src/types/Hooks.ts new file mode 100644 index 00000000..8fd9aa4f --- /dev/null +++ b/website/src/types/Hooks.ts @@ -0,0 +1,26 @@ +import { MutatorCallback, MutatorOptions } from "swr"; + +import { BaseTask, TaskResponse, TaskType } from "./Task"; + +type ConcreteTaskResponse = TaskResponse; +type TaskError = { errorCode: number; message: string }; + +type Trigger = ( + extraArgument?: unknown, + options?: MutatorOptions +) => Promise; + +type Reset = ( + data?: ConcreteTaskResponse | Promise | MutatorCallback, + opts?: boolean | MutatorOptions +) => Promise; + +type TaskAPIHook = { + tasks: TaskResponse[]; + isLoading: boolean; + error: TaskError; + trigger: Trigger; + reset: Reset; +}; + +export type TaskApiHooks = Record TaskAPIHook>; From 356fd775e93505b32535ca628fa142f987ab0415 Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Sat, 28 Jan 2023 06:52:40 +1100 Subject: [PATCH 093/111] Add emoji reactions and reporting for messages (website) (#952) * website: Move labelling to message ... menu and add reporting and emoji reactions We can add more emoji easily in future, we just need to pick ones that we have consistent icons for. Also added "open in new tab" option so that messages can be navigated to from tasks on mobile. * website: Make new label and report strings translatable. * website: Move report api call to oasst client * small fixes * pre-commit --------- Co-authored-by: AbdBarho --- website/public/locales/en/message.json | 11 ++ website/src/components/Messages.tsx | 2 +- .../src/components/Messages/LabelPopup.tsx | 76 ++++++++ .../Messages/MessageEmojiButton.stories.tsx | 34 ++++ .../Messages/MessageEmojiButton.tsx | 48 +++++ .../Messages/MessageTable.stories.tsx | 23 ++- .../src/components/Messages/MessageTable.tsx | 6 +- .../Messages/MessageTableEntry.stories.tsx | 25 ++- .../components/Messages/MessageTableEntry.tsx | 176 +++++++++++++++--- .../Messages/MessageWithChildren.tsx | 6 +- .../src/components/Messages/ReportPopup.tsx | 56 ++++++ website/src/components/Tasks/EvaluateTask.tsx | 1 - .../components/Tasks/LabelTask/LabelTask.tsx | 2 +- website/src/lib/oasst_api_client.ts | 34 +++- website/src/pages/api/messages/[id]/emoji.ts | 30 +++ website/src/pages/api/messages/[id]/index.ts | 11 +- website/src/pages/api/report.ts | 24 +++ website/src/pages/api/set_label.ts | 6 +- website/src/pages/messages/[id]/index.tsx | 2 +- website/src/types/Conversation.ts | 16 +- website/styles/Theme/colors.tsx | 2 + website/types/i18next.d.ts | 4 +- 22 files changed, 541 insertions(+), 54 deletions(-) create mode 100644 website/public/locales/en/message.json create mode 100644 website/src/components/Messages/LabelPopup.tsx create mode 100644 website/src/components/Messages/MessageEmojiButton.stories.tsx create mode 100644 website/src/components/Messages/MessageEmojiButton.tsx create mode 100644 website/src/components/Messages/ReportPopup.tsx create mode 100644 website/src/pages/api/messages/[id]/emoji.ts create mode 100644 website/src/pages/api/report.ts diff --git a/website/public/locales/en/message.json b/website/public/locales/en/message.json new file mode 100644 index 00000000..45ea04a1 --- /dev/null +++ b/website/public/locales/en/message.json @@ -0,0 +1,11 @@ +{ + "reactions": "Reactions", + "label_action": "Label", + "label_title": "Label", + "submit_labels": "Submit", + "open_new_tab_action": "Open in new tab", + "report_title": "Report", + "report_action": "Report", + "report_placeholder": "Why should this message be reviewed?", + "send_report": "Send" +} diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 0934c5af..c9d77e3c 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -20,7 +20,7 @@ export const Messages = ({ messages }: MessagesProps) => { return {items}; }; -export const MessageView = forwardRef((message: Message, ref) => { +export const MessageView = forwardRef, "div">((message: Partial, ref) => { const { colorMode } = useColorMode(); const bgColor = useMemo(() => { diff --git a/website/src/components/Messages/LabelPopup.tsx b/website/src/components/Messages/LabelPopup.tsx new file mode 100644 index 00000000..b2b95278 --- /dev/null +++ b/website/src/components/Messages/LabelPopup.tsx @@ -0,0 +1,76 @@ +import { + Button, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, +} from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { useState } from "react"; +import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; +import { get, post } from "src/lib/api"; +import useSWRImmutable from "swr/immutable"; +import useSWRMutation from "swr/mutation"; + +interface LabelMessagePopupProps { + messageId: string; + show: boolean; + onClose: () => void; +} + +interface Label { + name: string; + display_text: string; + help_text: string; +} + +interface ValidLabelsResponse { + valid_labels: Label[]; +} + +export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => { + const { t } = useTranslation("message"); + const { data: response } = useSWRImmutable("/api/valid_labels", get); + const valid_labels = response?.valid_labels ?? []; + const [values, setValues] = useState(null); + + const { trigger: setLabels } = useSWRMutation("/api/set_label", post); + + const submit = () => { + const label_map: Map = new Map(); + console.assert(valid_labels.length === values.length); + values.forEach((value, idx) => { + if (value !== null) { + label_map.set(valid_labels[idx].name, value); + } + }); + setLabels({ + message_id: messageId, + label_map: Object.fromEntries(label_map), + }); + + setValues(null); + onClose(); + }; + + return ( + + + + {t("label_title")} + + + name)} onChange={setValues} /> + + + + + + + ); +}; diff --git a/website/src/components/Messages/MessageEmojiButton.stories.tsx b/website/src/components/Messages/MessageEmojiButton.stories.tsx new file mode 100644 index 00000000..d74836d5 --- /dev/null +++ b/website/src/components/Messages/MessageEmojiButton.stories.tsx @@ -0,0 +1,34 @@ +import React from "react"; + +import { MessageEmojiButton } from "./MessageEmojiButton"; + +// eslint-disable-next-line import/no-anonymous-default-export +export default { + title: "Messages/MessageEmojiButton", + component: MessageEmojiButton, +}; + +const Template = ({ emoji, count, checked }: { emoji: string; count: number; checked?: boolean }) => { + return ; +}; + +export const Default = Template.bind({}); +Default.args = { + emoji: "+1", + count: 7, + checked: false, +}; + +export const BigNumber = Template.bind({}); +BigNumber.args = { + emoji: "+1", + count: 999, + checked: false, +}; + +export const Checked = Template.bind({}); +Checked.args = { + emoji: "+1", + count: 2, + checked: true, +}; diff --git a/website/src/components/Messages/MessageEmojiButton.tsx b/website/src/components/Messages/MessageEmojiButton.tsx new file mode 100644 index 00000000..8b5c9ff7 --- /dev/null +++ b/website/src/components/Messages/MessageEmojiButton.tsx @@ -0,0 +1,48 @@ +import { Button } from "@chakra-ui/react"; +import { BoxSelect, Flag, LucideProps, ThumbsDown, ThumbsUp } from "lucide-react"; +import { ReactElement } from "react"; +import { MessageEmoji } from "src/types/Conversation"; + +type EmojiIconPurpose = "MINI_BUTTON" | "NORMAL"; + +const defaultIconProps: (purpose: EmojiIconPurpose) => LucideProps = (purpose: EmojiIconPurpose) => { + if (purpose === "MINI_BUTTON") return { height: "1em" }; + return {}; +}; + +export const getEmojiIcon = (name: string, purpose: EmojiIconPurpose): ReactElement => { + switch (name) { + case "+1": + return ; + case "-1": + return ; + case "flag": + case "red_flag": + return ; + default: + return ; + } +}; + +interface MessageEmojiButtonProps { + emoji: MessageEmoji; + checked?: boolean; + onClick: () => void; +} + +export const MessageEmojiButton = ({ emoji, checked, onClick }: MessageEmojiButtonProps) => { + return ( + + ); +}; diff --git a/website/src/components/Messages/MessageTable.stories.tsx b/website/src/components/Messages/MessageTable.stories.tsx index 6f383b01..bc03aed1 100644 --- a/website/src/components/Messages/MessageTable.stories.tsx +++ b/website/src/components/Messages/MessageTable.stories.tsx @@ -29,18 +29,24 @@ Default.args = { is_assistant: true, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, { text: "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?", is_assistant: false, id: "", frontend_message_id: "", + emojis: { "-1": 11, red_flag: 2 }, + user_emojis: [], }, { text: "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, ], enableLink: true, @@ -50,12 +56,21 @@ Default.args = { export const Conversation = Template.bind({}); Conversation.args = { messages: [ - { text: "Hello! How can I help you?", is_assistant: true, id: "", frontend_message_id: "" }, + { + text: "Hello! How can I help you?", + is_assistant: true, + id: "", + frontend_message_id: "", + emojis: {}, + user_emojis: [], + }, { text: "Who were the 8 presidents before George Washington?", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, ], enableLink: false, @@ -70,18 +85,24 @@ LongText.args = { is_assistant: true, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, { text: "Yes, I think they can be helpful when the child misbehaves, but they should be used with a little bit of compassion and understanding that it\u2019s not the natural state of things to have an adult yelling at them. Time outs are also often used without letting the child know how they\u2019re getting out of the time out, which can make it feel arbitrary or like a punishment, rather than a consequence for something they did. It\u2019s really easy for adults to do this kind of thing unconsciously. It\u2019s easy to get caught up in the notion that \u201cThey\u2019re in time out, and that\u2019s the end of it!\u201d but kids can be pretty imaginative, and they can use their own creativity to make their way out of time outs. A compassionate time out ends when the child shows a sign of understanding what they\u2019ve done wrong, and are ready to begin again. That way the child knows they\u2019re learning, and that the parent is seeing them as an intelligent person, even if they sometimes mess up. You can still use the other techniques you were using to be tough when necessary, but using a compassionate approach will let you use them without actually using them!", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, { text: "No. The USA was founded by a Puritan group of Protestants, but it didn\u2019t adopt the religion of the Puritans until much later, and it was always a secular state. The Puritans observed the Sabbath on Sunday, and the Puritans only had a small influence in the early history of the USA. It\u2019s difficult to trace the origins of closing stores on Sunday, but one early and short-lived attempt at forcing the Sabbath on people in the 1800s was motivated by the Protestant ideal that people should spend Sunday focusing on spiritual activities. By the mid-1800s, when the Sunday closing law was made, there was not a lot of pressure from that standpoint, but the church had begun to advocate for Sunday closing laws as a way of counteracting the negative effects of industrialization on the day of rest. Even after that shift, closing stores on Sunday was not always possible, since the religious Sunday was not always chosen for observance. And as industrialization accelerated and mechanization made it possible to operate stores on Sunday, the law was not enforced as much as people liked. The day of rest was also being violated by stores that stayed open all day on Sunday, so closing stores on Sundays became an effort to protect the Sabbath for all citizens.", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, ], enableLink: true, diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index acf92e05..2d39f346 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -11,11 +11,11 @@ interface MessageTableProps { export function MessageTable({ messages, enableLink, highlightLastMessage }: MessageTableProps) { return ( - {messages.map((item, idx) => ( + {messages.map((message, idx) => ( ))} diff --git a/website/src/components/Messages/MessageTableEntry.stories.tsx b/website/src/components/Messages/MessageTableEntry.stories.tsx index b6071dd7..3550d00f 100644 --- a/website/src/components/Messages/MessageTableEntry.stories.tsx +++ b/website/src/components/Messages/MessageTableEntry.stories.tsx @@ -1,4 +1,5 @@ import React from "react"; +import { Message } from "src/types/Conversation"; import { MessageTableEntry } from "./MessageTableEntry"; @@ -8,10 +9,8 @@ export default { component: MessageTableEntry, }; -const Template = ({ text, is_assistant, id, frontend_message_id, enabled, highlight }) => { - return ( - - ); +const Template = ({ enabled, highlight, ...message }) => { + return ; }; export const Default = Template.bind({}); @@ -22,6 +21,8 @@ Default.args = { frontend_message_id: "", enabled: true, highlight: false, + emojis: {}, + user_emojis: [], }; export const Asistant = Template.bind({}); @@ -32,6 +33,8 @@ Asistant.args = { frontend_message_id: "", enabled: true, highlight: false, + emojis: {}, + user_emojis: [], }; export const LongText = Template.bind({}); @@ -42,4 +45,18 @@ LongText.args = { frontend_message_id: "", enabled: true, highlight: false, + emojis: {}, + user_emojis: [], +}; + +export const WithEmoji = Template.bind({}); +WithEmoji.args = { + text: "As you\u2019ve mentioned, Star Wars has many sequels, prequels, and crossovers. The official list of movies in Star Wars is:", + is_assistant: true, + id: "", + frontend_message_id: "", + enabled: true, + highlight: false, + emojis: { "-1": 5, "+1": 1 }, + user_emojis: ["-1"], }; diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 77202c44..3cde48f6 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -1,23 +1,47 @@ -import { Avatar, Box, HStack, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; +import { + Avatar, + Box, + HStack, + Menu, + MenuButton, + MenuDivider, + MenuGroup, + MenuItem, + MenuList, + SimpleGrid, + useBreakpointValue, + useColorModeValue, + useDisclosure, +} from "@chakra-ui/react"; import { boolean } from "boolean"; +import { ClipboardList, Flag, MessageSquare, MoreHorizontal } from "lucide-react"; import { useRouter } from "next/router"; -import { useCallback, useMemo } from "react"; -import { FlaggableElement } from "src/components/FlaggableElement"; -import { Message } from "src/types/Conversation"; +import { useTranslation } from "next-i18next"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { LabelMessagePopup } from "src/components/Messages/LabelPopup"; +import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton"; +import { ReportPopup } from "src/components/Messages/ReportPopup"; +import { post } from "src/lib/api"; +import { Message, MessageEmojis } from "src/types/Conversation"; import { colors } from "styles/Theme/colors"; +import useSWRMutation from "swr/mutation"; interface MessageTableEntryProps { - item: Message; + message: Message; enabled?: boolean; highlight?: boolean; } -export function MessageTableEntry(props: MessageTableEntryProps) { +export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) { const router = useRouter(); + const [emojis, setEmojis] = useState({ emojis: {}, user_emojis: [] }); + useEffect(() => { + setEmojis({ emojis: message.emojis, user_emojis: message.user_emojis }); + }, [message.emojis, message.user_emojis]); - const { item } = props; - - const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]); + const goToMessage = useCallback(() => router.push(`/messages/${message.id}`), [router, message.id]); + const { isOpen: reportPopupOpen, onOpen: showReportPopup, onClose: closeReportPopup } = useDisclosure(); + const { isOpen: labelPopupOpen, onOpen: showLabelPopup, onClose: closeLabelPopup } = useDisclosure(); const backgroundColor = useColorModeValue("gray.100", "gray.700"); const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B"); @@ -32,34 +56,124 @@ export function MessageTableEntry(props: MessageTableEntryProps) { borderColor={borderColor} size={inlineAvatar ? "xs" : "sm"} mr={inlineAvatar ? 2 : 0} - name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`} - src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`} + name={`${boolean(message.is_assistant) ? "Assistant" : "User"}`} + src={`${boolean(message.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`} /> ), - [borderColor, inlineAvatar, item.is_assistant] + [borderColor, inlineAvatar, message.is_assistant] ); const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight); + const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, { + onSuccess: setEmojis, + }); + const react = (emoji: string, state: boolean) => { + sendEmojiChange({ op: state ? "add" : "remove", emoji }); + }; + return ( - - - {!inlineAvatar && avatar} - + {!inlineAvatar && avatar} + + {inlineAvatar && avatar} + {message.text} + e.stopPropagation()} > - {inlineAvatar && avatar} - {item.text} - - - + {Object.entries(emojis.emojis).map(([emoji, count]) => ( + react(emoji, !emojis.user_emojis.includes(emoji))} + /> + ))} + + + + + + ); } + +const EmojiMenuItem = ({ + emoji, + checked, + react, +}: { + emoji: string; + checked?: boolean; + react: (emoji: string, state: boolean) => void; +}) => { + const activeColor = useColorModeValue(colors.light.active, colors.dark.active); + + return ( + react(emoji, !checked)} justifyContent="center" color={checked ? activeColor : undefined}> + {getEmojiIcon(emoji, "NORMAL")} + + ); +}; + +const MessageActions = ({ + react, + userEmoji, + onLabel, + onReport, + messageId, +}: { + react: (emoji: string, state: boolean) => void; + userEmoji: string[]; + onLabel: () => void; + onReport: () => void; + messageId: string; +}) => { + const { t } = useTranslation("message"); + + return ( + + + + + + + + {["+1", "-1"].map((emoji) => ( + + ))} + + + + }> + {t("label_action")} + + }> + {t("report_action")} + + + }> + {t("open_new_tab_action")} + + + + ); +}; diff --git a/website/src/components/Messages/MessageWithChildren.tsx b/website/src/components/Messages/MessageWithChildren.tsx index ca29c410..f47cfac5 100644 --- a/website/src/components/Messages/MessageWithChildren.tsx +++ b/website/src/components/Messages/MessageWithChildren.tsx @@ -52,7 +52,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { {isFirst ? "Message" : depth === 1 ? "Children" : "Ancestor"} - + @@ -86,9 +86,9 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { gap="4" shadow="base" > - {children.map((item, idx) => ( + {children.map((message, idx) => ( - + ))} diff --git a/website/src/components/Messages/ReportPopup.tsx b/website/src/components/Messages/ReportPopup.tsx new file mode 100644 index 00000000..67a2ea1b --- /dev/null +++ b/website/src/components/Messages/ReportPopup.tsx @@ -0,0 +1,56 @@ +import { + Button, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + Textarea, +} from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { useState } from "react"; +import { post } from "src/lib/api"; +import useSWRMutation from "swr/mutation"; + +interface ReportPopupProps { + messageId: string; + show: boolean; + onClose: () => void; +} + +export const ReportPopup = ({ messageId, show, onClose }: ReportPopupProps) => { + const { t } = useTranslation("message"); + const [text, setText] = useState(""); + const { trigger } = useSWRMutation("/api/report", post); + + const submit = () => { + trigger({ + message_id: messageId, + text, + }); + + setText(""); + onClose(); + }; + + return ( + + + + {t("report_title")} + + +