From df1eca4eafd18381b965fa78bd55631e99c7eb88 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Thu, 19 Jan 2023 17:21:41 +0900 Subject: [PATCH 01/15] Ensuring the website uses the most specific auth type with the backend when fetching and interacting with tasks --- website/src/lib/oasst_api_client.ts | 19 +++------- website/src/lib/users.ts | 38 +++++++++++++++++++ website/src/pages/api/new_task/[task_type].ts | 4 +- website/src/pages/api/update_task.ts | 5 ++- website/src/types/Users.ts | 10 +++-- 5 files changed, 56 insertions(+), 20 deletions(-) create mode 100644 website/src/lib/users.ts diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index d48a987c..de58ae71 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,7 +1,6 @@ -import { JWT } from "next-auth/jwt"; import type { Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; -import type { BackendUser } from "src/types/Users"; +import type { BackendUser, BackendUserCore } from "src/types/Users"; export class OasstError { message: string; @@ -108,14 +107,10 @@ export class OasstApiClient { // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. // This is a raw Json type, so we can't use it to strongly type the task. - async fetchTask(taskType: string, userToken: JWT): Promise { + async fetchTask(taskType: string, user: BackendUserCore): Promise { return this.post("/api/v1/tasks/", { type: taskType, - user: { - id: userToken.sub, - display_name: userToken.name, - auth_method: "local", - }, + user, }); } @@ -140,15 +135,11 @@ export class OasstApiClient { messageId: string, userMessageId: string, content: object, - userToken: JWT + user: BackendUserCore ): Promise { return this.post("/api/v1/tasks/interaction", { type: updateType, - user: { - id: userToken.sub, - display_name: userToken.name, - auth_method: "local", - }, + user, task_id: taskId, message_id: messageId, user_message_id: userMessageId, diff --git a/website/src/lib/users.ts b/website/src/lib/users.ts new file mode 100644 index 00000000..2aa8c708 --- /dev/null +++ b/website/src/lib/users.ts @@ -0,0 +1,38 @@ +import prisma from "src/lib/prismadb"; +import type { BackendUserCore } from "src/types/Users"; + +/** + * Returns a `BackendUserCore` that can be used for interacting with the Backend service. + * + * @param {string} id The user's web auth id. + * + * @return {BackendUserCore} The most specific auth type and id for the user. + */ +const getBackendUserCore = async (id: string) => { + const user = await prisma.user.findUnique({ + where: { id }, + select: { + id: true, + name: true, + accounts: true, + }, + }); + + // If there are no linked accounts, just use what we have locally. + if (user.accounts.length === 0) { + return { + id: user.id, + display_name: user.name, + auth_method: "local", + } as BackendUserCore; + } + + // Otherwise, use the first linked account that the user created. + return { + id: user.accounts[0].providerAccountId, + display_name: user.name, + auth_method: user.accounts[0].provider, + } as BackendUserCore; +}; + +export { getBackendUserCore }; diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index e77c5eb2..c8255b18 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,6 +1,7 @@ import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; +import { getBackendUserCore } from "src/lib/users"; /** * Returns a new task created from the Task Backend. We do a few things here: @@ -14,9 +15,10 @@ const handler = withoutRole("banned", async (req, res, token) => { // Fetch the new task. const { task_type } = req.query; + const user = await getBackendUserCore(token.sub); let task; try { - task = await oasstApiClient.fetchTask(task_type as string, token); + task = await oasstApiClient.fetchTask(task_type as string, user); } catch (err) { console.error(err); res.status(500).json(err); diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 02982daa..b9de5c50 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -2,6 +2,8 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; +import { getBackendUserCore } from "src/lib/users"; +import type { BackendUserCore } from "src/types/Users"; /** * Stores the task interaction with the Task Backend and then returns the next task generated. @@ -39,9 +41,10 @@ const handler = withoutRole("banned", async (req, res, token) => { }, }); + const user = await getBackendUserCore(token.sub); let newTask; try { - newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, token); + newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, user); } catch (err) { console.error(JSON.stringify(err)); return res.status(500).json(err); diff --git a/website/src/types/Users.ts b/website/src/types/Users.ts index eeb1903a..39d2a663 100644 --- a/website/src/types/Users.ts +++ b/website/src/types/Users.ts @@ -1,7 +1,4 @@ -/** - * Reports the Backend's knowledge of a user. - */ -export interface BackendUser { +export interface BackendUserCore { /** * The user's unique ID according to the `auth_method`. */ @@ -18,7 +15,12 @@ export interface BackendUser { * - local */ auth_method: string; +} +/** + * Reports the Backend's knowledge of a user. + */ +export interface BackendUser extends BackendUserCore { /** * The backend's UUID for this user. */ From ae6600afa2dfa428bcbf301dfbc8c196e8e89061 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Thu, 19 Jan 2023 17:23:53 +0900 Subject: [PATCH 02/15] Removing an unused import --- website/src/pages/api/update_task.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index b9de5c50..c547503a 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -3,7 +3,6 @@ import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import { getBackendUserCore } from "src/lib/users"; -import type { BackendUserCore } from "src/types/Users"; /** * Stores the task interaction with the Task Backend and then returns the next task generated. From 581f31203ca94de45ba55e12c0a8c1a944ca3a74 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Thu, 19 Jan 2023 17:32:05 +0900 Subject: [PATCH 03/15] Fixing the contract tests to use the new user type --- .../contract/oasst_api_contract_tests.cy.ts | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/website/cypress/contract/oasst_api_contract_tests.cy.ts b/website/cypress/contract/oasst_api_contract_tests.cy.ts index cf1c7506..d2ffeba3 100644 --- a/website/cypress/contract/oasst_api_contract_tests.cy.ts +++ b/website/cypress/contract/oasst_api_contract_tests.cy.ts @@ -1,34 +1,27 @@ import { OasstApiClient, OasstError } from "src/lib/oasst_api_client"; +import type { BackendUserCore } from "src/types/Users"; describe("Contract test for Oasst API", function () { // Assumes this is running the mock server. const oasstApiClient = new OasstApiClient("http://localhost:8080", "test"); + const testUser = { + id: "abcd", + display_name: "test", + auth_method: "local", + } as BackendUserCore; + it("can fetch a task", async () => { - expect( - await oasstApiClient.fetchTask("random", { - sub: "test", - name: "test", - email: "test", - }) - ).to.be.not.null; + expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null; }); it("can ack a task", async () => { - const task = await oasstApiClient.fetchTask("random", { - sub: "test", - name: "test", - email: "test", - }); + const task = await oasstApiClient.fetchTask("random", testUser); expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null; }); it("can record a taskInteraction", async () => { - const task = await oasstApiClient.fetchTask("random", { - sub: "test", - name: "test", - email: "test", - }); + const task = await oasstApiClient.fetchTask("random", testUser); expect( await oasstApiClient.interactTask( "text_reply_to_message", @@ -36,11 +29,7 @@ describe("Contract test for Oasst API", function () { "321", "1", { text: "Test" }, - { - sub: "test", - name: "test", - email: "test", - } + testUser ) ).to.be.not.null; }); From 3fa2e637d26d8a132c974d6232c497c3b3a55298 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Thu, 19 Jan 2023 16:01:36 +0000 Subject: [PATCH 04/15] Add localization to pages, header and footer Fix import orders for useTranslation Apply common translations for header and footer lint Fix getServerSideProps messages/id --- website/public/locales/en/common.json | 15 ++++++++++++- website/src/components/Footer.tsx | 22 ++++++++++--------- website/src/components/Header/Header.tsx | 8 ++++--- website/src/components/Header/UserMenu.tsx | 20 +++++++++-------- website/src/components/Hero.tsx | 2 +- website/src/components/Layout.tsx | 6 ++--- website/src/pages/404.tsx | 7 ++++++ website/src/pages/500.tsx | 7 ++++++ website/src/pages/about.tsx | 7 ++++++ website/src/pages/account/edit.tsx | 7 ++++++ website/src/pages/account/index.tsx | 7 ++++++ website/src/pages/admin/index.tsx | 7 ++++++ website/src/pages/admin/manage_user/[id].tsx | 4 +++- website/src/pages/auth/signin.tsx | 7 +++--- website/src/pages/create/assistant_reply.tsx | 7 ++++++ website/src/pages/create/initial_prompt.tsx | 7 ++++++ website/src/pages/create/user_reply.tsx | 7 ++++++ website/src/pages/dashboard.tsx | 7 ++++++ .../pages/evaluate/rank_assistant_replies.tsx | 7 ++++++ .../pages/evaluate/rank_initial_prompts.tsx | 7 ++++++ .../src/pages/evaluate/rank_user_replies.tsx | 7 ++++++ .../src/pages/label/label_assistant_reply.tsx | 7 ++++++ .../src/pages/label/label_initial_prompt.tsx | 7 ++++++ .../src/pages/label/label_prompter_reply.tsx | 7 ++++++ website/src/pages/leaderboard.tsx | 7 ++++++ website/src/pages/messages/[id]/index.tsx | 14 +++++++----- website/src/pages/messages/index.tsx | 7 ++++++ website/src/pages/privacy-policy.tsx | 7 ++++++ website/src/pages/terms-of-service.tsx | 7 ++++++ 29 files changed, 202 insertions(+), 36 deletions(-) diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index 0b2df79c..e18eb8ec 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -1,4 +1,17 @@ { + "about": "About", + "account_settings": "Account Settings", + "connect": "Connect", + "conversational": "Conversational AI for everyone.", + "dashboard": "Dashboard", "discord": "Discord", - "github": "GitHub" + "docs": "Docs", + "github": "GitHub", + "legal": "Legal", + "privacy_policy": "Privacy Policy", + "report_a_bug": "Report a Bug", + "sign_in": "Sign In", + "sign_out": "Sign Out", + "terms_of_service": "Terms of Service", + "title": "Open Assistant" } diff --git a/website/src/components/Footer.tsx b/website/src/components/Footer.tsx index b239708a..68cd7c01 100644 --- a/website/src/components/Footer.tsx +++ b/website/src/components/Footer.tsx @@ -1,9 +1,11 @@ import { Box, Divider, Flex, Text, useColorMode } from "@chakra-ui/react"; import Image from "next/image"; import Link from "next/link"; +import { useTranslation } from "next-i18next"; import { useMemo } from "react"; export function Footer() { + const { t } = useTranslation(); const { colorMode } = useColorMode(); const backgroundColor = colorMode === "light" ? "white" : "gray.800"; const textColor = colorMode === "light" ? "black" : "gray.300"; @@ -33,10 +35,10 @@ export function Footer() { - Open Assistant + {t("title")} - Conversational AI for everyone. + {t("conversational")} @@ -45,23 +47,23 @@ export function Footer() { - Legal + {t("legal")} - - + + - Connect + {t("connect")} - - + + - About + {t("about")} - + diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx index de5abec2..a1b36123 100644 --- a/website/src/components/Header/Header.tsx +++ b/website/src/components/Header/Header.tsx @@ -1,7 +1,8 @@ -import { Box, Button, Text, Flex } from "@chakra-ui/react"; +import { Box, Button, Flex, Text } from "@chakra-ui/react"; import Image from "next/image"; import Link from "next/link"; import { useSession } from "next-auth/react"; +import { useTranslation } from "next-i18next"; import { Flags } from "react-feature-flags"; import { FaUser } from "react-icons/fa"; @@ -23,7 +24,8 @@ function AccountButton() { ); } -export function Header(props) { +export function Header() { + const { t } = useTranslation(); const { data: session } = useSession(); const homeURL = session ? "/dashboard" : "/"; @@ -34,7 +36,7 @@ export function Header(props) { logo - Open Assistant + {t("title")} diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 99ec01f1..6fdde69e 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -13,6 +13,7 @@ import { } from "@chakra-ui/react"; import NextLink from "next/link"; import { signOut, useSession } from "next-auth/react"; +import { useTranslation } from "next-i18next"; import React, { ElementType, useCallback } from "react"; import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi"; @@ -25,6 +26,7 @@ interface MenuOption { } export function UserMenu() { + const { t } = useTranslation(); const borderColor = useColorModeValue("gray.300", "gray.600"); const handleSignOut = useCallback(() => { signOut({ callbackUrl: "/" }); @@ -36,23 +38,23 @@ export function UserMenu() { } const options: MenuOption[] = [ { - name: "Dashboard", + name: t("dashboard"), href: "/dashboard", - desc: "Dashboard", + desc: t("dashboard"), icon: FiLayout, isExternal: false, }, { - name: "Account Settings", + name: t("account_settings"), href: "/account", - desc: "Account Settings", + desc: t("account_settings"), icon: FiSettings, isExternal: false, }, { - name: "Report a Bug", + name: t("report_a_bug"), href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose", - desc: "Report a Bug", + desc: t("report_a_bug"), icon: FiAlertTriangle, isExternal: true, }, @@ -60,9 +62,9 @@ export function UserMenu() { if (session.user.role === "admin") { options.unshift({ - name: "Admin Dashboard", + name: t("admin_dashboard"), href: "/admin", - desc: "Admin Dashboard", + desc: t("admin_dashboard"), icon: FiShield, isExternal: false, }); @@ -105,7 +107,7 @@ export function UserMenu() { diff --git a/website/src/components/Hero.tsx b/website/src/components/Hero.tsx index 4605e9e2..d401e47e 100644 --- a/website/src/components/Hero.tsx +++ b/website/src/components/Hero.tsx @@ -2,8 +2,8 @@ import { Box, Text, useColorMode } from "@chakra-ui/react"; import Image from "next/image"; import { useTranslation } from "next-i18next"; -import { Container } from "./Container"; import { AnimatedCircles } from "./AnimatedCircles"; +import { Container } from "./Container"; export function Hero() { const { t } = useTranslation("index"); diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index 70a2ce2c..484d16ec 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -23,7 +23,7 @@ export const getDefaultLayout = (page: React.ReactElement) => ( export const getTransparentHeaderLayout = (page: React.ReactElement) => (
-
+
{page}
@@ -31,7 +31,7 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => ( export const getDashboardLayout = (page: React.ReactElement) => ( -
+
( export const getAdminLayout = (page: React.ReactElement) => (
-
+
({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default Error; diff --git a/website/src/pages/500.tsx b/website/src/pages/500.tsx index 49eb2950..bd0fac2f 100644 --- a/website/src/pages/500.tsx +++ b/website/src/pages/500.tsx @@ -1,5 +1,6 @@ import { Box, Button, Center, Link, Text } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { FiAlertTriangle } from "react-icons/fi"; import { EmptyState } from "src/components/EmptyState"; import { getTransparentHeaderLayout } from "src/components/Layout"; @@ -43,4 +44,10 @@ function ServerError() { ServerError.getLayout = getTransparentHeaderLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default ServerError; diff --git a/website/src/pages/about.tsx b/website/src/pages/about.tsx index 490a6095..08d2bea7 100644 --- a/website/src/pages/about.tsx +++ b/website/src/pages/about.tsx @@ -1,4 +1,5 @@ import Image from "next/image"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { CallToAction } from "src/components/CallToAction"; import { Container } from "src/components/Container"; import Roadmap from "src/components/Roadmap"; @@ -36,4 +37,10 @@ const AboutPage = () => { ); }; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default AboutPage; diff --git a/website/src/pages/account/edit.tsx b/website/src/pages/account/edit.tsx index fe8e8981..9120322a 100644 --- a/website/src/pages/account/edit.tsx +++ b/website/src/pages/account/edit.tsx @@ -2,6 +2,7 @@ import { Button, Input, InputGroup } from "@chakra-ui/react"; import Head from "next/head"; import Router from "next/router"; import { useSession } from "next-auth/react"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import React from "react"; import { Control, useForm, useWatch } from "react-hook-form"; @@ -30,6 +31,12 @@ export default function Account() { ); } +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + const EditForm = () => { const { data: session } = useSession(); diff --git a/website/src/pages/account/index.tsx b/website/src/pages/account/index.tsx index d26fc842..b6e95594 100644 --- a/website/src/pages/account/index.tsx +++ b/website/src/pages/account/index.tsx @@ -2,6 +2,7 @@ import { Button } from "@chakra-ui/react"; import Head from "next/head"; import Link from "next/link"; import { useSession } from "next-auth/react"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import React from "react"; export default function Account() { @@ -31,3 +32,9 @@ export default function Account() { ); } + +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx index 9cbea222..ef5abe2c 100644 --- a/website/src/pages/admin/index.tsx +++ b/website/src/pages/admin/index.tsx @@ -1,6 +1,7 @@ import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { useEffect } from "react"; import { getAdminLayout } from "src/components/Layout"; import UsersCell from "src/components/UsersCell"; @@ -44,4 +45,10 @@ const AdminIndex = () => { AdminIndex.getLayout = getAdminLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default AdminIndex; diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index 88bfced4..b53bb7c0 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -3,6 +3,7 @@ import { InferGetServerSidePropsType } from "next"; import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { useEffect } from "react"; import { useForm } from "react-hook-form"; import { getAdminLayout } from "src/components/Layout"; @@ -111,7 +112,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType { @@ -151,7 +151,7 @@ function Signin({ providers }: SigninProps) { Signin.getLayout = (page) => (
-
+
{page}
@@ -209,11 +209,12 @@ const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSaf ); }; -export const getServerSideProps: GetServerSideProps = async () => { +export const getServerSideProps: GetServerSideProps = async ({ locale }) => { const providers = await getProviders(); return { props: { providers, + ...(await serverSideTranslations(locale, ["common"])), }, }; }; diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index cceeaf4e..9002ad3e 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const AssistantReply = () => { AssistantReply.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default AssistantReply; diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 6a51ca25..496e000c 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const InitialPrompt = () => { InitialPrompt.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default InitialPrompt; diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 8d2981e5..86f29826 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const UserReply = () => { UserReply.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default UserReply; diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index 78a47fd4..ed2b20e1 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -1,5 +1,6 @@ import { Flex } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard"; import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; @@ -22,4 +23,10 @@ const Dashboard = () => { Dashboard.getLayout = (page) => getDashboardLayout(page); +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default Dashboard; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index 695fbfdc..dcd2d30b 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const RankAssistantReplies = () => { RankAssistantReplies.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default RankAssistantReplies; diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 4eaaa110..4c997faa 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const RankInitialPrompts = () => { RankInitialPrompts.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default RankInitialPrompts; diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index dd23030e..0982d36d 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const RankUserReplies = () => { RankUserReplies.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default RankUserReplies; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 5cb45278..6f005f33 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const LabelAssistantReply = () => { LabelAssistantReply.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default LabelAssistantReply; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index d7c1d4b2..a6813499 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const LabelInitialPrompt = () => { LabelInitialPrompt.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default LabelInitialPrompt; diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index b48e6aab..f1ba8008 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,4 +1,5 @@ import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; @@ -29,4 +30,10 @@ const LabelPrompterReply = () => { LabelPrompterReply.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default LabelPrompterReply; diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index e53b0c52..d6bae8e9 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -1,5 +1,6 @@ import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; @@ -45,4 +46,10 @@ const Leaderboard = () => { Leaderboard.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default Leaderboard; diff --git a/website/src/pages/messages/[id]/index.tsx b/website/src/pages/messages/[id]/index.tsx index f55c03cc..51c28c42 100644 --- a/website/src/pages/messages/[id]/index.tsx +++ b/website/src/pages/messages/[id]/index.tsx @@ -1,5 +1,6 @@ import { Box, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { MessageLoading } from "src/components/Loading/MessageLoading"; import { MessageTableEntry } from "src/components/Messages/MessageTableEntry"; @@ -48,10 +49,13 @@ const MessageDetail = ({ id }: { id: string }) => { ); }; -MessageDetail.getInitialProps = async ({ query }) => { - const { id } = query; - return { id }; -}; - MessageDetail.getLayout = (page) => getDashboardLayout(page); + +export const getServerSideProps = async ({ locale, query }) => ({ + props: { + id: query.id, + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default MessageDetail; diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 627a8b18..40497fd1 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -1,5 +1,6 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { MessageTable } from "src/components/Messages/MessageTable"; import { get } from "src/lib/api"; @@ -54,4 +55,10 @@ const MessagesDashboard = () => { MessagesDashboard.getLayout = getDashboardLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default MessagesDashboard; diff --git a/website/src/pages/privacy-policy.tsx b/website/src/pages/privacy-policy.tsx index 1c94b669..2f171483 100644 --- a/website/src/pages/privacy-policy.tsx +++ b/website/src/pages/privacy-policy.tsx @@ -1,5 +1,6 @@ import { Box, Heading, Link, Stack, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getTransparentHeaderLayout } from "src/components/Layout"; import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard"; import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard"; @@ -224,4 +225,10 @@ const PrivacyPolicy = () => { PrivacyPolicy.getLayout = getTransparentHeaderLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default PrivacyPolicy; diff --git a/website/src/pages/terms-of-service.tsx b/website/src/pages/terms-of-service.tsx index b0e298ba..3a414292 100644 --- a/website/src/pages/terms-of-service.tsx +++ b/website/src/pages/terms-of-service.tsx @@ -1,5 +1,6 @@ import { Box, Heading, Stack } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getTransparentHeaderLayout } from "src/components/Layout"; import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard"; import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard"; @@ -189,4 +190,10 @@ const TermsOfService = () => { TermsOfService.getLayout = getTransparentHeaderLayout; +export const getStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default TermsOfService; From 74cb9aaa5af31185aca4e86c9f1990ea6489adad Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Fri, 20 Jan 2023 03:02:07 +0000 Subject: [PATCH 05/15] [feature] added translation, rallio instruct tuning dataset, prosocial for safety, new summary dataset --- .../instructor/configs/deberta-v3-base.yml | 2 +- .../supervised_finetuning/configs/config.yaml | 8 + .../custom_datasets/README.md | 26 +++ .../custom_datasets/__init__.py | 46 ++++- .../custom_datasets/dialogue_collator.py | 1 + .../custom_datasets/prompt_dialogue.py | 54 ++++++ .../custom_datasets/qa_datasets.py | 22 ++- .../custom_datasets/summarization.py | 29 +++- .../custom_datasets/toxic_conversation.py | 65 ++++++++ .../custom_datasets/translation.py | 157 ++++++++++++++++++ .../tests/test_datasets.py | 7 +- 11 files changed, 399 insertions(+), 18 deletions(-) create mode 100644 model/supervised_finetuning/custom_datasets/README.md create mode 100644 model/supervised_finetuning/custom_datasets/toxic_conversation.py create mode 100644 model/supervised_finetuning/custom_datasets/translation.py diff --git a/model/reward/instructor/configs/deberta-v3-base.yml b/model/reward/instructor/configs/deberta-v3-base.yml index 7023709c..134cfdaa 100644 --- a/model/reward/instructor/configs/deberta-v3-base.yml +++ b/model/reward/instructor/configs/deberta-v3-base.yml @@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base learning_rate: 1e-5 scheduler: cosine gradient_checkpointing: false -gradient_accumulation_steps: 32 +gradient_accumulation_steps: 16 per_device_train_batch_size: 2 warmup_steps: 600 eval_steps: 200 diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 2eaa6686..42a0ae2c 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -29,6 +29,14 @@ defaults: - soda - joke - gsm8k + - dive_mt + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja + - instruct_tuning + - wmt2019_de-en - samsum cache_dir: .cache loss_fn: CrossEntropyLoss diff --git a/model/supervised_finetuning/custom_datasets/README.md b/model/supervised_finetuning/custom_datasets/README.md new file mode 100644 index 00000000..9c825932 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/README.md @@ -0,0 +1,26 @@ +# Dataset collections overview: + +currently dataset can be divided into 3 classes + +- language knowledge + + - summarization + + - translation + +- dialogue : don't let user know you are a robot + +- STEM : knowledge about the world + + - coding + + - world knowledge <= ideally we want to handle this via prefix context + +Issues and TODO: + +* as dataset are growing, how can we update this section less + +* ideally we can update the config yaml and new dataset will be download from hub + + * one possible idea is we upload the trasform format of these dataset to the OA hub + diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index e293af3d..cb844777 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,11 +1,26 @@ -from custom_datasets.prompt_dialogue import PromptGeneratedDataset +""" + High level functions for model training +""" +from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT from custom_datasets.summarization import SummarizationDataset +from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination +from custom_datasets.translation import WMT2019, DiveMT, TEDTalk from sklearn.model_selection import train_test_split from torch.utils.data import Subset QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"] -SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"] +SUMMARIZATION_DATASETS = [ + "xsum", + "cnn_dailymail", + "samsum", + "multi_news", + "scitldr", + "billsum", + "debate_sum", + "tldr_news", +] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"] def train_val_dataset(dataset, val_split=0.2): @@ -25,20 +40,43 @@ def get_one_dataset(conf, dataset_name): elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - val_name = "validation" if dataset_name not in ["billsum"] else "test" - eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + if dataset_name == "debate_sum": + train, eval = train_val_dataset(train, val_split=0.2) + else: + val_name = "validation" if dataset_name not in ["billsum"] else "test" + eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + elif "ted_trans" in dataset_name: + language_pair = dataset_name.split("_")[-1] + dataset = TEDTalk(pair=language_pair, split="train") + train, eval = train_val_dataset(dataset, val_split=0.2) + elif "wmt2019" in dataset_name: + language_pair = dataset_name.split("_")[-1] + train = WMT2019(pair=language_pair, split="train") + eval = WMT2019(pair=language_pair, split="validation") + elif dataset_name == "dive_mt": + dataset = DiveMT() + train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "prompt_dialogue": dataset = PromptGeneratedDataset(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "prosocial_dialogue": + train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train") + eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation") + elif dataset_name == "explain_prosocial": + train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train") + eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation") elif dataset_name == "soda": dataset = SODA(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.1) elif dataset_name == "joke": dataset = JokeExplaination(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "instruct_tuning": + dataset = InstructionTuning(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.2) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 2efe160f..719fa0d6 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -25,6 +25,7 @@ class DialogueDataCollator: for feature_one in features: assert len(feature_one) % 2 == 0, "Number of messages must be even" + # TODO: we should push this to dataset __getitem__ messages = [ (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") + x diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 372ea27f..4a1d83a3 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -1,3 +1,4 @@ +import json import os from urllib.request import urlopen @@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset): we are ignoring results with multiple lines for now """ + name = "prompt_dialogue" url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt" def __init__(self, cache_dir) -> None: @@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer + + +class InstructionTuning(Dataset): + """ + We have seen some promising capabilities from instruction tuning + with the following mix of datasets that are derived from datasets + available online. + The files for this data are in json format as a list of tuples + where each tuple is (source,instruction_response_pair) + + - instruction_tuning_dataset_alpha_part1.json + - instruction_tuning_dataset_alpha_part2.json + + Not to be confused with unatural instruction + """ + + name = "instruction_dataset" + url_part_2 = ( + "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json" + ) + url_part_1 = ( + "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json" + ) + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + + self.pairs = [] + for file_link in [self.url_part_1, self.url_part_2]: + basename = file_link.split("/")[-1] + instruction_tune_file = os.path.join(cache_dir, basename) + if not os.path.exists(instruction_tune_file): + with urlopen(file_link) as file: + content = file.read().decode() + with open(instruction_tune_file, "w", encoding="utf-8") as fout: + fout.write(content) + + with open(instruction_tune_file, "r", encoding="utf-8") as f: + datasets = json.load(f) + for row in datasets: + _, response_pair = row + question, answer = response_pair.split("\n\n", maxsplit=1) + answer = answer.replace("<|endoftext|>", "").strip() + self.pairs.append((question, answer)) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index eed9c644..789b8f58 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -1,11 +1,18 @@ +""" + Open / close book QA datasets +""" import json import os +import re from urllib.request import urlopen import numpy as np from datasets import load_dataset from torch.utils.data import Dataset +# @agoryuno contributed this +re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") + QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} @@ -75,6 +82,9 @@ class QADataset(Dataset): class WebGPT(Dataset): + + name = "webgpt" + def __init__(self) -> None: super().__init__() @@ -89,7 +99,9 @@ class WebGPT(Dataset): self.index2question[len(self.index2question)] = question # only keep the best answer - questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + questions[question] = re_reference_remove.sub( + "", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + ) self.questions = questions @@ -103,6 +115,9 @@ class WebGPT(Dataset): class SODA(Dataset): + + name = "soda" + def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] @@ -149,8 +164,8 @@ class SODA(Dataset): class JokeExplaination(Dataset): - """ """ + name = "joke" url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" def __init__(self, cache_dir) -> None: @@ -182,3 +197,6 @@ class JokeExplaination(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer + + +# https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 69e4b51d..2a097fe7 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -1,3 +1,6 @@ +""" + Summarize different spectrum of documents +""" import random from datasets import load_dataset @@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = { } summarization_config_mapping = { - "cnn_dailymail": ("3.0.0",), - "samsum": (), - "xsum": (), - "multi_news": (), - "scitldr": ("AIC",), - "billsum": (), - "reddit": (), + "cnn_dailymail": ( + "cnn_dailymail", + "3.0.0", + ), + "samsum": ("samsum",), + "xsum": ("xsum",), + "multi_news": ("multi_news",), + "scitldr": ( + "scitldr", + "AIC", + ), + "billsum": ("billsum",), + "reddit": ("reddit",), + "tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news + "debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum } summarization_name_mapping = { @@ -29,6 +40,8 @@ summarization_name_mapping = { "scitldr": ("source", "target"), "billsum": ("text", "summary"), "reddit": ("content", "summary"), + "tldr_news": ("content", "headline"), + "debate_sum": ("Full-Document", "Extract"), } @@ -43,7 +56,7 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): def __init__(self, dataset, cache_dir, split): self.name = dataset - self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py new file mode 100644 index 00000000..6ef29163 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -0,0 +1,65 @@ +''' + SFT dataset to reject toxic questions + +''' +import random +from datasets import load_dataset +from torch.utils.data import Dataset + +class ProsocialDialogueExplaination(Dataset): + name = "prosocial_explain" + TEMPLATE = [ + # 0 : reply or sentence of interest, 1 : reason of caution + ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"), + ("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"), + ("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"), + ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"), + ] + def __init__(self, split='train', cache_dir='.cache') -> None: + super().__init__() + dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] + self.pairs = [] + for row in dataset: + for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']): + (prompt_template, answer_template) = random.choice(self.TEMPLATE) + self.pairs.append(( + prompt_template.format(row['context'],safety_annotation), + answer_template.format( safe_answer, safety_annotation) + )) + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + return self.pairs[idx] + +class ProsocialDialogue(Dataset): + name = "prosocial_dialogue" + ''' + ProsocialDialog, we set up a human-AI collaborative data creation framework, + where GPT-3 generates the potentially unsafe utterances, and crowdworkers + provide prosocial responses to them. This approach allows us to circumvent + two substantial challenges: + (1) there are no available large-scale corpora of multiturn prosocial conversations + between humans + (2) asking humans to write unethical, toxic, or problematic utterances could result + in psychological harms (Roberts, 2017; Steiger et al., 2021). + ''' + PREFIX = "You are now a prosocial chatbot, be caution and casual when reply" + + + def __init__(self, split='train', cache_dir='.cache') -> None: + super().__init__() + dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] + self.pairs = [] + for row in dataset: + for answer in row['rots']: + self.pairs.append(( + self.PREFIX+row['context'], + answer + )) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + return self.pairs[idx] diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py new file mode 100644 index 00000000..a6d46e9e --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -0,0 +1,157 @@ +''' + List of translation dataset + + GroNLP/divemt + + fill in the blanks : https://huggingface.co/datasets/m_lama + +''' +import random +from datasets import load_dataset +from torch.utils.data import Dataset + +# postfix prompt +TRANSLATION_PROMPT = { + "zh": [ # simplified or any chinese which was not mentioned + "Translate to chinese simplified: {}", + "{}, translate to chinese", + "{} give me the chinese translation", + "翻译成中文: {}", + "{} 这句中文翻译怎麽写?", + "我需要这句话的中文翻译: {}", + ], + "zh-tw": [ # WMT code + "{}. Translate to chinese traditional", + "{}, translate to chinese", + "{}. get chinese translation", + "中文翻譯: {}", + "幫我翻譯成中文: '{}'", + "{} 這句中文翻譯怎麼寫?", + ], + "ja": [ + "{}: help me translate to japanese", + "Need japanese translation: {}", + "{}: にほんごやくをよこす", + "{}: にほんごやくをおくれ", + "{}: にほんごやくを じょす", + "give me the japanese translation, {}", + ], + "de": [ + "{}: translate to german", + "give me the german translation {}", + "I want german translation {}", + "{}, ins Deutsche übersetzen", + "{}, Übersetzen ins Deutsche", + ], + "fr": [ + "{}. translate to french", + "{} write in french", + "{} french translation", + "{} ,donnez moi la traduction française"], + "ko": [ + "{}. translate to Korean", + "how do we write in korean: {}", + "give me the korean translation: {}", + "{}, 한국어 번역을 해주세요", + ], + "ms": [ + "{} translate to malay", + "{} how do we write in Malay", + "{} give me the malay translation", + "{} , berikan saya terjemahan dalam bahasa melayu", + "{}, Jemahan di bahasa melayu" + "{}, jemahkan ayat ini kepada bahasa melayu" + ], + "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"], + "tr": ["{}. translate to turkish", "{} write in turkish", "turkish translation: '{}'"], + "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"], + "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"], + "vi": ["{}. translate to vietnamese", "{} write in vietnamese", "vietnamese translation: '{}'"], + "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"], +} +class TranslationPair(Dataset): + def __init__(self) -> None: + super().__init__() + self.pairs = [] + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + return self.pairs[index] + + +class WMT2019(TranslationPair): + + def __init__(self, pair='zh-en', split='train') -> None: + super().__init__() + dataset = load_dataset('wmt19', pair)[split] + self.pairs = [] + src, tgt = pair.split('-') + for row in dataset: + row = row['translation'] + if random.random() > 0.5: + source = random.choice( + TRANSLATION_PROMPT[tgt] + ).format(row[src]) + self.pairs.append((source, row[tgt])) + else:# translating in reverse direction + source = random.choice( + TRANSLATION_PROMPT[src] + ).format(row[tgt]) + self.pairs.append((source, row[src])) + +class DiveMT(TranslationPair): + + REMAP = { + 'tur': 'tr', + 'ita': 'it', + 'ukr': 'uk', + 'nld': 'nl', + 'vie': 'vi', + 'ara': 'ar' + } + + def __init__(self, split='train') -> None: + super().__init__() + dataset = load_dataset('GroNLP/divemt', 'main')[split] + tgt, src = 'tgt_text', 'src_text' + for row in dataset: + # ISO 639-2 + lang_code_2 = row['subject_id'].split('_')[0] + lang_code = self.REMAP[lang_code_2] + if lang_code not in TRANSLATION_PROMPT: + continue + + if random.random() > 0.5: + source = random.choice( + TRANSLATION_PROMPT[lang_code] + ).format(row[src]) + self.pairs.append((source, row[tgt])) + else:# translating in reverse direction + lang_code = 'en' + source = random.choice( + TRANSLATION_PROMPT[lang_code] + ).format(row[tgt]) + self.pairs.append((source, row[src])) + + +class TEDTalk(TranslationPair): + # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean + + def __init__(self, pair='de-ja', split='train', year='2016') -> None: + super().__init__() + dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split] + src, tgt = pair.split('-') + for row in dataset: + row = row['translation'] + if random.random() > 0.5: + source = random.choice( + TRANSLATION_PROMPT[tgt] + ).format(row[src]) + self.pairs.append((source, row[tgt])) + else:# translating in reverse direction + source = random.choice( + TRANSLATION_PROMPT[src] + ).format(row[tgt]) + self.pairs.append((source, row[src])) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index c9363303..2ac43613 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS - others = ["prompt_dialogue", "webgpt", "soda", "joke"] + others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"] + translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") - for dataset_name in others + qa_base + summarize_base: + for dataset_name in translation: print(dataset_name) train, eval = get_one_dataset(config, dataset_name) # sanity check @@ -51,4 +52,4 @@ def test_collate_fn(): if __name__ == "__main__": - test_collate_fn() + test_all_datasets() From 6cd62e3d488afc338dcbb0f5f0cda209480fd730 Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Fri, 20 Jan 2023 06:16:26 +0000 Subject: [PATCH 06/15] [fix] Fix missing russian and update readme --- model/supervised_finetuning/README.md | 17 +++++++++++++++++ .../custom_datasets/translation.py | 5 +++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index 822121d8..9f200847 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -60,6 +60,23 @@ python trainer.py --configs defaults your-model-name --deepspeed ## Dataset choices +To specify which translation pair for [WMT](https://huggingface.co/datasets/wmt19) and [TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply add the supported language pair at the postfix + +``` + datasets: + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja +``` + +Currently only these languages are supported via prompt translation: + +``` +ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh +``` + ## Results Experimental results in wandb diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index a6d46e9e..79fff0d1 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -63,10 +63,11 @@ TRANSLATION_PROMPT = { "{}, jemahkan ayat ini kepada bahasa melayu" ], "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"], - "tr": ["{}. translate to turkish", "{} write in turkish", "turkish translation: '{}'"], + "ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"], + "tr": ["{}. türkçeye çevi̇ri̇n", "{} write in turkish", "turkish translation: '{}'", "türkçeye çevi̇rmek: {}"], "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"], "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"], - "vi": ["{}. translate to vietnamese", "{} write in vietnamese", "vietnamese translation: '{}'"], + "vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"], "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"], } class TranslationPair(Dataset): From 22e3ab1a890876691ab382e811b6a486f2fa3eeb Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Fri, 20 Jan 2023 07:23:02 +0000 Subject: [PATCH 07/15] [fix] linter fix --- model/supervised_finetuning/README.md | 5 +- .../custom_datasets/README.md | 19 ++-- .../custom_datasets/__init__.py | 2 +- .../custom_datasets/toxic_conversation.py | 37 ++++---- .../custom_datasets/translation.py | 86 ++++++++----------- 5 files changed, 70 insertions(+), 79 deletions(-) diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index 9f200847..d5b10e01 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -60,7 +60,10 @@ python trainer.py --configs defaults your-model-name --deepspeed ## Dataset choices -To specify which translation pair for [WMT](https://huggingface.co/datasets/wmt19) and [TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply add the supported language pair at the postfix +To specify which translation pair for +[WMT](https://huggingface.co/datasets/wmt19) and +[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply +add the supported language pair at the postfix ``` datasets: diff --git a/model/supervised_finetuning/custom_datasets/README.md b/model/supervised_finetuning/custom_datasets/README.md index 9c825932..56a28574 100644 --- a/model/supervised_finetuning/custom_datasets/README.md +++ b/model/supervised_finetuning/custom_datasets/README.md @@ -4,23 +4,24 @@ currently dataset can be divided into 3 classes - language knowledge - - summarization + - summarization - - translation + - translation - dialogue : don't let user know you are a robot - STEM : knowledge about the world - - coding - - - world knowledge <= ideally we want to handle this via prefix context + - coding + + - world knowledge <= ideally we want to handle this via prefix context Issues and TODO: -* as dataset are growing, how can we update this section less +- as dataset are growing, how can we update this section less -* ideally we can update the config yaml and new dataset will be download from hub - - * one possible idea is we upload the trasform format of these dataset to the OA hub +- ideally we can update the config yaml and new dataset will be download from + hub + - one possible idea is we upload the trasform format of these dataset to the + OA hub diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index cef3a409..2e1e4b30 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -2,7 +2,7 @@ High level functions for model training """ from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset -from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT, SODADialogue +from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination from custom_datasets.translation import WMT2019, DiveMT, TEDTalk diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py index 6ef29163..815ac722 100644 --- a/model/supervised_finetuning/custom_datasets/toxic_conversation.py +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -1,11 +1,13 @@ -''' +""" SFT dataset to reject toxic questions -''' +""" import random + from datasets import load_dataset from torch.utils.data import Dataset + class ProsocialDialogueExplaination(Dataset): name = "prosocial_explain" TEMPLATE = [ @@ -15,26 +17,31 @@ class ProsocialDialogueExplaination(Dataset): ("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"), ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"), ] - def __init__(self, split='train', cache_dir='.cache') -> None: + + def __init__(self, split="train", cache_dir=".cache") -> None: super().__init__() dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: - for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']): + for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]): (prompt_template, answer_template) = random.choice(self.TEMPLATE) - self.pairs.append(( - prompt_template.format(row['context'],safety_annotation), - answer_template.format( safe_answer, safety_annotation) - )) + self.pairs.append( + ( + prompt_template.format(row["context"], safety_annotation), + answer_template.format(safe_answer, safety_annotation), + ) + ) + def __len__(self): return len(self.pairs) def __getitem__(self, idx): return self.pairs[idx] + class ProsocialDialogue(Dataset): name = "prosocial_dialogue" - ''' + """ ProsocialDialog, we set up a human-AI collaborative data creation framework, where GPT-3 generates the potentially unsafe utterances, and crowdworkers provide prosocial responses to them. This approach allows us to circumvent @@ -43,20 +50,16 @@ class ProsocialDialogue(Dataset): between humans (2) asking humans to write unethical, toxic, or problematic utterances could result in psychological harms (Roberts, 2017; Steiger et al., 2021). - ''' + """ PREFIX = "You are now a prosocial chatbot, be caution and casual when reply" - - def __init__(self, split='train', cache_dir='.cache') -> None: + def __init__(self, split="train", cache_dir=".cache") -> None: super().__init__() dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: - for answer in row['rots']: - self.pairs.append(( - self.PREFIX+row['context'], - answer - )) + for answer in row["rots"]: + self.pairs.append((self.PREFIX + row["context"], answer)) def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index 79fff0d1..694d31ce 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -1,18 +1,19 @@ -''' +""" List of translation dataset GroNLP/divemt fill in the blanks : https://huggingface.co/datasets/m_lama -''' +""" import random + from datasets import load_dataset from torch.utils.data import Dataset # postfix prompt TRANSLATION_PROMPT = { - "zh": [ # simplified or any chinese which was not mentioned + "zh": [ # simplified or any chinese which was not mentioned "Translate to chinese simplified: {}", "{}, translate to chinese", "{} give me the chinese translation", @@ -20,7 +21,7 @@ TRANSLATION_PROMPT = { "{} 这句中文翻译怎麽写?", "我需要这句话的中文翻译: {}", ], - "zh-tw": [ # WMT code + "zh-tw": [ # WMT code "{}. Translate to chinese traditional", "{}, translate to chinese", "{}. get chinese translation", @@ -47,7 +48,8 @@ TRANSLATION_PROMPT = { "{}. translate to french", "{} write in french", "{} french translation", - "{} ,donnez moi la traduction française"], + "{} ,donnez moi la traduction française", + ], "ko": [ "{}. translate to Korean", "how do we write in korean: {}", @@ -59,8 +61,7 @@ TRANSLATION_PROMPT = { "{} how do we write in Malay", "{} give me the malay translation", "{} , berikan saya terjemahan dalam bahasa melayu", - "{}, Jemahan di bahasa melayu" - "{}, jemahkan ayat ini kepada bahasa melayu" + "{}, Jemahan di bahasa melayu" "{}, jemahkan ayat ini kepada bahasa melayu", ], "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"], "ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"], @@ -70,6 +71,8 @@ TRANSLATION_PROMPT = { "vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"], "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"], } + + class TranslationPair(Dataset): def __init__(self) -> None: super().__init__() @@ -80,79 +83,60 @@ class TranslationPair(Dataset): def __getitem__(self, index): return self.pairs[index] - + class WMT2019(TranslationPair): - - def __init__(self, pair='zh-en', split='train') -> None: + def __init__(self, pair="zh-en", split="train") -> None: super().__init__() - dataset = load_dataset('wmt19', pair)[split] + dataset = load_dataset("wmt19", pair)[split] self.pairs = [] - src, tgt = pair.split('-') + src, tgt = pair.split("-") for row in dataset: - row = row['translation'] + row = row["translation"] if random.random() > 0.5: - source = random.choice( - TRANSLATION_PROMPT[tgt] - ).format(row[src]) + source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src]) self.pairs.append((source, row[tgt])) - else:# translating in reverse direction - source = random.choice( - TRANSLATION_PROMPT[src] - ).format(row[tgt]) + else: # translating in reverse direction + source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) + class DiveMT(TranslationPair): - REMAP = { - 'tur': 'tr', - 'ita': 'it', - 'ukr': 'uk', - 'nld': 'nl', - 'vie': 'vi', - 'ara': 'ar' - } + REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"} - def __init__(self, split='train') -> None: + def __init__(self, split="train") -> None: super().__init__() - dataset = load_dataset('GroNLP/divemt', 'main')[split] - tgt, src = 'tgt_text', 'src_text' + dataset = load_dataset("GroNLP/divemt", "main")[split] + tgt, src = "tgt_text", "src_text" for row in dataset: # ISO 639-2 - lang_code_2 = row['subject_id'].split('_')[0] + lang_code_2 = row["subject_id"].split("_")[0] lang_code = self.REMAP[lang_code_2] if lang_code not in TRANSLATION_PROMPT: continue if random.random() > 0.5: - source = random.choice( - TRANSLATION_PROMPT[lang_code] - ).format(row[src]) + source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[src]) self.pairs.append((source, row[tgt])) - else:# translating in reverse direction - lang_code = 'en' - source = random.choice( - TRANSLATION_PROMPT[lang_code] - ).format(row[tgt]) + else: # translating in reverse direction + lang_code = "en" + source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[tgt]) self.pairs.append((source, row[src])) class TEDTalk(TranslationPair): # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean - def __init__(self, pair='de-ja', split='train', year='2016') -> None: + def __init__(self, pair="de-ja", split="train", year="2016") -> None: super().__init__() - dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split] - src, tgt = pair.split('-') + dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split] + src, tgt = pair.split("-") for row in dataset: - row = row['translation'] + row = row["translation"] if random.random() > 0.5: - source = random.choice( - TRANSLATION_PROMPT[tgt] - ).format(row[src]) + source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src]) self.pairs.append((source, row[tgt])) - else:# translating in reverse direction - source = random.choice( - TRANSLATION_PROMPT[src] - ).format(row[tgt]) + else: # translating in reverse direction + source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) From aca3e9de89db216e1d2726ffcdf2d6eaabf51a8a Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Fri, 20 Jan 2023 07:26:26 +0000 Subject: [PATCH 08/15] [fix] wait it pass? --- model/supervised_finetuning/tests/test_datasets.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 2ac43613..3b59f289 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -11,7 +11,7 @@ def test_all_datasets(): translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") - for dataset_name in translation: + for dataset_name in translation + others + summarize_base + qa_base: print(dataset_name) train, eval = get_one_dataset(config, dataset_name) # sanity check @@ -49,7 +49,3 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: assert batch["targets"].shape[1] <= 512 - - -if __name__ == "__main__": - test_all_datasets() From 9b659191ca5759d56d159331d134d4bde86c4aa7 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Fri, 20 Jan 2023 09:37:02 +0000 Subject: [PATCH 09/15] Extract common getStaticProps to single file --- website/src/lib/default_static_props.ts | 7 +++++++ website/src/pages/404.tsx | 8 +------- website/src/pages/500.tsx | 8 +------- website/src/pages/about.tsx | 8 +------- website/src/pages/account/edit.tsx | 8 +------- website/src/pages/account/index.tsx | 8 +------- website/src/pages/admin/index.tsx | 8 +------- website/src/pages/create/assistant_reply.tsx | 8 +------- website/src/pages/create/initial_prompt.tsx | 8 +------- website/src/pages/create/user_reply.tsx | 8 +------- website/src/pages/dashboard.tsx | 8 +------- website/src/pages/evaluate/rank_assistant_replies.tsx | 8 +------- website/src/pages/evaluate/rank_initial_prompts.tsx | 8 +------- website/src/pages/evaluate/rank_user_replies.tsx | 8 +------- website/src/pages/label/label_assistant_reply.tsx | 8 +------- website/src/pages/label/label_initial_prompt.tsx | 8 +------- website/src/pages/label/label_prompter_reply.tsx | 8 +------- website/src/pages/leaderboard.tsx | 8 +------- website/src/pages/messages/index.tsx | 8 +------- website/src/pages/privacy-policy.tsx | 8 +------- website/src/pages/terms-of-service.tsx | 8 +------- 21 files changed, 27 insertions(+), 140 deletions(-) create mode 100644 website/src/lib/default_static_props.ts diff --git a/website/src/lib/default_static_props.ts b/website/src/lib/default_static_props.ts new file mode 100644 index 00000000..3cc311e8 --- /dev/null +++ b/website/src/lib/default_static_props.ts @@ -0,0 +1,7 @@ +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; + +export const getDefaultStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale)), + }, +}); \ No newline at end of file diff --git a/website/src/pages/404.tsx b/website/src/pages/404.tsx index 11185057..d4c58b54 100644 --- a/website/src/pages/404.tsx +++ b/website/src/pages/404.tsx @@ -1,9 +1,9 @@ import { Box, Button, Center, Link, Text } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { FiAlertTriangle } from "react-icons/fi"; import { EmptyState } from "src/components/EmptyState"; import { getTransparentHeaderLayout } from "src/components/Layout"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; function Error() { return ( @@ -41,10 +41,4 @@ function Error() { Error.getLayout = getTransparentHeaderLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default Error; diff --git a/website/src/pages/500.tsx b/website/src/pages/500.tsx index bd0fac2f..378bdfff 100644 --- a/website/src/pages/500.tsx +++ b/website/src/pages/500.tsx @@ -1,9 +1,9 @@ import { Box, Button, Center, Link, Text } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { FiAlertTriangle } from "react-icons/fi"; import { EmptyState } from "src/components/EmptyState"; import { getTransparentHeaderLayout } from "src/components/Layout"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; function ServerError() { return ( @@ -44,10 +44,4 @@ function ServerError() { ServerError.getLayout = getTransparentHeaderLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default ServerError; diff --git a/website/src/pages/about.tsx b/website/src/pages/about.tsx index 08d2bea7..01182be0 100644 --- a/website/src/pages/about.tsx +++ b/website/src/pages/about.tsx @@ -1,10 +1,10 @@ import Image from "next/image"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { CallToAction } from "src/components/CallToAction"; import { Container } from "src/components/Container"; import Roadmap from "src/components/Roadmap"; import Services from "src/components/Services"; import Vision from "src/components/Vision"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const AboutPage = () => { return ( @@ -37,10 +37,4 @@ const AboutPage = () => { ); }; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default AboutPage; diff --git a/website/src/pages/account/edit.tsx b/website/src/pages/account/edit.tsx index 9120322a..52af7e5e 100644 --- a/website/src/pages/account/edit.tsx +++ b/website/src/pages/account/edit.tsx @@ -2,9 +2,9 @@ import { Button, Input, InputGroup } from "@chakra-ui/react"; import Head from "next/head"; import Router from "next/router"; import { useSession } from "next-auth/react"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import React from "react"; import { Control, useForm, useWatch } from "react-hook-form"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; export default function Account() { const { data: session } = useSession(); @@ -31,12 +31,6 @@ export default function Account() { ); } -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - const EditForm = () => { const { data: session } = useSession(); diff --git a/website/src/pages/account/index.tsx b/website/src/pages/account/index.tsx index b6e95594..813ce8ea 100644 --- a/website/src/pages/account/index.tsx +++ b/website/src/pages/account/index.tsx @@ -2,8 +2,8 @@ import { Button } from "@chakra-ui/react"; import Head from "next/head"; import Link from "next/link"; import { useSession } from "next-auth/react"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import React from "react"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; export default function Account() { const { data: session } = useSession(); @@ -32,9 +32,3 @@ export default function Account() { ); } - -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx index ef5abe2c..f8827049 100644 --- a/website/src/pages/admin/index.tsx +++ b/website/src/pages/admin/index.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { useEffect } from "react"; import { getAdminLayout } from "src/components/Layout"; import UsersCell from "src/components/UsersCell"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; /** * Provides the admin index page that will display a list of users and give @@ -45,10 +45,4 @@ const AdminIndex = () => { AdminIndex.getLayout = getAdminLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default AdminIndex; diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index 9002ad3e..1c83eb23 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const AssistantReply = () => { const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); @@ -30,10 +30,4 @@ const AssistantReply = () => { AssistantReply.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default AssistantReply; diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 496e000c..639df68f 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const InitialPrompt = () => { const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); @@ -30,10 +30,4 @@ const InitialPrompt = () => { InitialPrompt.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default InitialPrompt; diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 86f29826..5898439c 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const UserReply = () => { const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); @@ -30,10 +30,4 @@ const UserReply = () => { UserReply.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default UserReply; diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index ed2b20e1..3d6beb8b 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -1,9 +1,9 @@ import { Flex } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard"; import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const Dashboard = () => { return ( @@ -23,10 +23,4 @@ const Dashboard = () => { Dashboard.getLayout = (page) => getDashboardLayout(page); -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default Dashboard; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index dcd2d30b..da79d92f 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const RankAssistantReplies = () => { const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); @@ -30,10 +30,4 @@ const RankAssistantReplies = () => { RankAssistantReplies.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default RankAssistantReplies; diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 4c997faa..f23fc0ed 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const RankInitialPrompts = () => { const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); @@ -30,10 +30,4 @@ const RankInitialPrompts = () => { RankInitialPrompts.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default RankInitialPrompts; diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index 0982d36d..cee82b87 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const RankUserReplies = () => { const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); @@ -30,10 +30,4 @@ const RankUserReplies = () => { RankUserReplies.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default RankUserReplies; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 6f005f33..07a6cb1c 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const LabelAssistantReply = () => { const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask(); @@ -30,10 +30,4 @@ const LabelAssistantReply = () => { LabelAssistantReply.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default LabelAssistantReply; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index a6813499..8735044f 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const LabelInitialPrompt = () => { const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask(); @@ -30,10 +30,4 @@ const LabelInitialPrompt = () => { LabelInitialPrompt.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default LabelInitialPrompt; diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index f1ba8008..17164e11 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,10 +1,10 @@ import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const LabelPrompterReply = () => { const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask(); @@ -30,10 +30,4 @@ const LabelPrompterReply = () => { LabelPrompterReply.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default LabelPrompterReply; diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index d6bae8e9..f79dac52 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -1,8 +1,8 @@ import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; const Leaderboard = () => { @@ -46,10 +46,4 @@ const Leaderboard = () => { Leaderboard.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default Leaderboard; diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 40497fd1..3b6e342e 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -1,10 +1,10 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { MessageTable } from "src/components/Messages/MessageTable"; import { get } from "src/lib/api"; import useSWRImmutable from "swr/immutable"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const MessagesDashboard = () => { const boxBgColor = useColorModeValue("white", "gray.800"); @@ -55,10 +55,4 @@ const MessagesDashboard = () => { MessagesDashboard.getLayout = getDashboardLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default MessagesDashboard; diff --git a/website/src/pages/privacy-policy.tsx b/website/src/pages/privacy-policy.tsx index 2f171483..f84dc1e8 100644 --- a/website/src/pages/privacy-policy.tsx +++ b/website/src/pages/privacy-policy.tsx @@ -1,9 +1,9 @@ import { Box, Heading, Link, Stack, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getTransparentHeaderLayout } from "src/components/Layout"; import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard"; import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const PrivacyPolicy = () => { const backgroundColor = useColorModeValue("gray.100", "gray.800"); @@ -225,10 +225,4 @@ const PrivacyPolicy = () => { PrivacyPolicy.getLayout = getTransparentHeaderLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default PrivacyPolicy; diff --git a/website/src/pages/terms-of-service.tsx b/website/src/pages/terms-of-service.tsx index 3a414292..41269bdf 100644 --- a/website/src/pages/terms-of-service.tsx +++ b/website/src/pages/terms-of-service.tsx @@ -1,9 +1,9 @@ import { Box, Heading, Stack } from "@chakra-ui/react"; import Head from "next/head"; -import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getTransparentHeaderLayout } from "src/components/Layout"; import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard"; import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const TermsOfService = () => { const TermsData = [ @@ -190,10 +190,4 @@ const TermsOfService = () => { TermsOfService.getLayout = getTransparentHeaderLayout; -export const getStaticProps = async ({ locale }) => ({ - props: { - ...(await serverSideTranslations(locale, ["common"])), - }, -}); - export default TermsOfService; From 47c402e7620430574c7027b93fc80e9e7bd28d79 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 20 Jan 2023 10:37:54 +0100 Subject: [PATCH 10/15] More space for messages on narrow screen --- website/src/components/FlaggableElement.tsx | 8 +-- .../src/components/Messages/MessageTable.tsx | 2 +- .../components/Messages/MessageTableEntry.tsx | 67 ++++++++++--------- website/src/components/Tasks/TaskTypes.tsx | 2 +- 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 9ba227ee..7e28f2c2 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -25,8 +25,8 @@ import clsx from "clsx"; import { useEffect, useReducer } from "react"; import { FiAlertCircle } from "react-icons/fi"; import { get, post } from "src/lib/api"; -import { Message } from "src/types/Conversation"; import { colors } from "src/styles/Theme/colors"; +import { Message } from "src/types/Conversation"; import useSWR from "swr"; import useSWRMutation from "swr/mutation"; @@ -114,9 +114,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => { }, [data, isLoading]); const { trigger } = useSWRMutation("/api/set_label", post, { - onSuccess: () => { - setIsEditing.off(); - }, + onSuccess: setIsEditing.off, }); const submitResponse = () => { @@ -149,7 +147,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => { isLazy lazyBehavior="keepMounted" > - + {props.children} diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index 45a13d2f..ed98752c 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -9,7 +9,7 @@ interface MessageTableProps { export function MessageTable({ messages, enableLink }: MessageTableProps) { return ( - + {messages.map((item) => ( ))} diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 8e9d03b6..d18bd910 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -1,6 +1,8 @@ -import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; +import { Avatar, Box, HStack, LinkBox, useBreakpoint, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; import Link from "next/link"; +import { useRouter } from "next/router"; +import { useCallback, useMemo } from "react"; import { FlaggableElement } from "src/components/FlaggableElement"; import { Message } from "src/types/Conversation"; @@ -10,47 +12,48 @@ interface MessageTableEntryProps { } export function MessageTableEntry(props: MessageTableEntryProps) { + const router = useRouter(); + const { item } = props; + + const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]); + const backgroundColor = useColorModeValue("gray.100", "gray.700"); const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B"); - const avatarColor = useColorModeValue("white", "black"); const borderColor = useColorModeValue("blackAlpha.200", "whiteAlpha.200"); + const inlineAvatar = useBreakpointValue({ base: true, sm: false }); + + const avatar = useMemo( + () => ( + + ), + [borderColor, inlineAvatar, item.is_assistant] + ); + return ( - - + {!inlineAvatar && avatar} + + {inlineAvatar && avatar} + {item.text} - {props.enabled ? ( - - - - {item.text} - - - - ) : ( - - {item.text} - - )} ); diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 868a9fb8..b58de0f8 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -121,7 +121,7 @@ export const TaskTypes: TaskInfo[] = [ category: TaskCategory.Label, pathname: "/label/label_prompter_reply", help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/label_prompter_reply", - overview: "Given the following discussion, provide labels for the final prompt", + overview: "Given the following discussion, provide labels for the final prompt.", type: "label_prompter_reply", mode: "full", update_type: "text_labels", From f09eee1ce7265cd617c8d34ef67be0997ff7c8fd Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Fri, 20 Jan 2023 09:44:06 +0000 Subject: [PATCH 11/15] Pre-commit --- website/src/lib/default_static_props.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/lib/default_static_props.ts b/website/src/lib/default_static_props.ts index 3cc311e8..365099cf 100644 --- a/website/src/lib/default_static_props.ts +++ b/website/src/lib/default_static_props.ts @@ -4,4 +4,4 @@ export const getDefaultStaticProps = async ({ locale }) => ({ props: { ...(await serverSideTranslations(locale)), }, -}); \ No newline at end of file +}); From 1fe24db8ae6bb92486082bea40085f57d09d9df9 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 20 Jan 2023 11:05:56 +0100 Subject: [PATCH 12/15] Restyle account page --- website/public/locales/en/common.json | 2 +- website/src/components/Survey/SurveyCard.tsx | 12 +++---- website/src/pages/account/index.tsx | 35 ++++++++++++++------ 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index e18eb8ec..99edf6c3 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -1,6 +1,6 @@ { "about": "About", - "account_settings": "Account Settings", + "account_settings": "Account", "connect": "Connect", "conversational": "Conversational AI for everyone.", "dashboard": "Dashboard", diff --git a/website/src/components/Survey/SurveyCard.tsx b/website/src/components/Survey/SurveyCard.tsx index 5a78ce2b..94657fc7 100644 --- a/website/src/components/Survey/SurveyCard.tsx +++ b/website/src/components/Survey/SurveyCard.tsx @@ -1,22 +1,18 @@ import { Box, BoxProps, useColorModeValue } from "@chakra-ui/react"; +import { PropsWithChildren } from "react"; -interface SurveyCardProps { - className?: string; - children: React.ReactNode; -} - -export const SurveyCard = (props: SurveyCardProps) => { +export const SurveyCard = (props: PropsWithChildren<{ className?: string }>) => { const backgroundColor = useColorModeValue("white", "gray.700"); const BoxClasses: BoxProps = { gap: "2", borderRadius: "xl", shadow: "base", - className: "p-4 sm:p-6", + className: "p-4 sm:p-6 " + (props.className ?? ""), }; return ( - + {props.children} ); diff --git a/website/src/pages/account/index.tsx b/website/src/pages/account/index.tsx index 813ce8ea..88964b3b 100644 --- a/website/src/pages/account/index.tsx +++ b/website/src/pages/account/index.tsx @@ -1,9 +1,11 @@ -import { Button } from "@chakra-ui/react"; +import { Button, Divider, Flex, Grid, Icon, Text } from "@chakra-ui/react"; import Head from "next/head"; import Link from "next/link"; import { useSession } from "next-auth/react"; import React from "react"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; +import { MdOutlineEdit } from "react-icons/md"; +import { SurveyCard } from "src/components/Survey/SurveyCard"; export default function Account() { const { data: session } = useSession(); @@ -20,15 +22,28 @@ export default function Account() { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
-
-

{session.user.name || "No username"}

- -

{session.user.email}

-
-
+
+ + + + Your Account + + + + Username + + {session.user.name ?? "(No username)"} + + + + + Email + {session.user.email ?? "(No Email)"} + +

+
+
+
); } From e80a69dd8a17a7dd703d6d3090b985d11df1f579 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 20 Jan 2023 12:08:10 +0100 Subject: [PATCH 13/15] Fetch available tasks --- .../src/components/Dashboard/TaskOption.tsx | 14 +++++++------- website/src/components/Tasks/TaskTypes.tsx | 11 +++++++++-- website/src/hooks/tasks/useGenericTaskAPI.tsx | 9 ++++----- website/src/lib/oasst_api_client.ts | 8 ++++++++ website/src/pages/api/available_tasks.ts | 11 +++++++++++ website/src/pages/dashboard.tsx | 19 +++++++++++++++++-- website/src/pages/tasks/random.tsx | 3 ++- website/src/types/Task.ts | 4 ++++ 8 files changed, 62 insertions(+), 17 deletions(-) create mode 100644 website/src/pages/api/available_tasks.ts diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx index 1b421126..e2bafac3 100644 --- a/website/src/components/Dashboard/TaskOption.tsx +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -1,19 +1,19 @@ import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Link from "next/link"; -import { TaskTypes } from "../Tasks/TaskTypes"; +import { TaskCategory, TaskCategoryLabels, TaskTypes } from "../Tasks/TaskTypes"; -export const TaskOption = ({ displayTaskCategories }) => { +export const TaskOption = ({ displayTaskCategories }: { displayTaskCategories: TaskCategory[] }) => { const backgroundColor = useColorModeValue("white", "gray.700"); return ( - {displayTaskCategories.map((category, categoryIndex) => ( -
- {category} + {displayTaskCategories.map((category) => ( +
+ {TaskCategoryLabels[category]} - {TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => ( - + {TaskTypes.filter((task) => task.category === category).map((item) => ( + (taskApiEndpoint: string) => { +export const useGenericTaskAPI = (taskType: TaskTypeEnum) => { type ConcreteTaskResponse = TaskResponse; const [tasks, setTasks] = useState([]); - const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, get, { + const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskType, get, { onSuccess: (data) => setTasks([data]), revalidateOnMount: true, dedupingInterval: 500, }); const { trigger } = useSWRMutation("/api/update_task", post, { - onSuccess: async (response) => { - const newTask: ConcreteTaskResponse = response; + onSuccess: async (newTask: ConcreteTaskResponse) => { setTasks((oldTasks) => [...oldTasks, newTask]); mutate(); }, diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index de58ae71..b1639462 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,5 +1,6 @@ import type { Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; +import type { AvailableTasks } from "src/types/Task"; import type { BackendUser, BackendUserCore } from "src/types/Users"; export class OasstError { @@ -205,6 +206,13 @@ export class OasstApiClient { async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise { return this.get(`/api/v1/leaderboards/${time_frame}`); } + + /** + * Returns the counts of all tasks (some might be zero) + */ + async fetch_available_tasks(user: BackendUserCore): Promise { + return this.post(`/api/v1/tasks/availability`, user); + } } const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/pages/api/available_tasks.ts b/website/src/pages/api/available_tasks.ts new file mode 100644 index 00000000..36265d25 --- /dev/null +++ b/website/src/pages/api/available_tasks.ts @@ -0,0 +1,11 @@ +import { withoutRole } from "src/lib/auth"; +import { oasstApiClient } from "src/lib/oasst_api_client"; +import { getBackendUserCore } from "src/lib/users"; + +const handler = withoutRole("banned", async (req, res, token) => { + const user = await getBackendUserCore(token.sub); + const availableTasks = await oasstApiClient.fetch_available_tasks(user); + res.status(200).json(availableTasks); +}); + +export default handler; diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index 3d6beb8b..e0b8bba4 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -1,11 +1,20 @@ import { Flex } from "@chakra-ui/react"; import Head from "next/head"; +import { useMemo } from "react"; import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard"; import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; +import { get } from "src/lib/api"; +import type { AvailableTasks, TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; +import useSWRImmutable from "swr/immutable"; const Dashboard = () => { + const { data } = useSWRImmutable("/api/available_tasks", get); + + // TODO: show only these tasks: + const availableTasks = useMemo(() => filterAvailableTasks(data ?? {}), [data]); + return ( <> @@ -14,13 +23,19 @@ const Dashboard = () => { - + ); }; -Dashboard.getLayout = (page) => getDashboardLayout(page); +Dashboard.getLayout = getDashboardLayout; export default Dashboard; + +const filterAvailableTasks = (availableTasks: Partial) => + Object.entries(availableTasks) + .filter(([_, count]) => count > 0) + .sort((a, b) => b[1] - a[1]) + .map(([taskType]) => taskType) as TaskType[]; diff --git a/website/src/pages/tasks/random.tsx b/website/src/pages/tasks/random.tsx index d2e850f5..be1809c3 100644 --- a/website/src/pages/tasks/random.tsx +++ b/website/src/pages/tasks/random.tsx @@ -4,9 +4,10 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI"; +import { TaskType } from "src/types/Task"; const RandomTask = () => { - const { tasks, isLoading, trigger, reset } = useGenericTaskAPI("random"); + const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random); if (isLoading) { return ; diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index d58f892c..8e5ada44 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -10,6 +10,8 @@ export const enum TaskType { label_initial_prompt = "label_initial_prompt", label_prompter_reply = "label_prompter_reply", label_assistant_reply = "label_assistant_reply", + + random = "random", } // we need to reconsider how to handle task content types @@ -32,3 +34,5 @@ export interface TaskResponse { userId: string; task: Task; } + +export type AvailableTasks = { [taskType in TaskType]: number }; From a588844ba5d1674f48c9b33abf10c5ab6d1104fe Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Fri, 20 Jan 2023 12:26:50 +0100 Subject: [PATCH 14/15] use clsx --- website/src/components/Survey/SurveyCard.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/website/src/components/Survey/SurveyCard.tsx b/website/src/components/Survey/SurveyCard.tsx index 94657fc7..6101f787 100644 --- a/website/src/components/Survey/SurveyCard.tsx +++ b/website/src/components/Survey/SurveyCard.tsx @@ -1,4 +1,5 @@ import { Box, BoxProps, useColorModeValue } from "@chakra-ui/react"; +import clsx from "clsx"; import { PropsWithChildren } from "react"; export const SurveyCard = (props: PropsWithChildren<{ className?: string }>) => { @@ -8,7 +9,7 @@ export const SurveyCard = (props: PropsWithChildren<{ className?: string }>) => gap: "2", borderRadius: "xl", shadow: "base", - className: "p-4 sm:p-6 " + (props.className ?? ""), + className: clsx("p-4 sm:p-6", props.className), }; return ( From 70fc80aa0889865c374a89ee7fe1e94141f42cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 20 Jan 2023 16:32:13 +0100 Subject: [PATCH 15/15] Add keyset pagination for users ordered by `username` / `display_name` (#851) * add keyset pagination for user ordered by username or display_name * add index on display-name for user table * update down_revision in migration script --- ...f26fec4d204_add_ix_user_display_name_id.py | 26 +++++ .../oasst_backend/api/v1/frontend_users.py | 39 +++---- backend/oasst_backend/api/v1/users.py | 60 ++++++++++- backend/oasst_backend/models/user.py | 5 +- backend/oasst_backend/user_repository.py | 102 +++++++++++++----- 5 files changed, 180 insertions(+), 52 deletions(-) create mode 100644 backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py diff --git a/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py b/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py new file mode 100644 index 00000000..19b497fa --- /dev/null +++ b/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py @@ -0,0 +1,26 @@ +"""add ix_user_display_name_id + +Revision ID: 4f26fec4d204 +Revises: 0964ac95170d +Create Date: 2023-01-19 22:00:00 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4f26fec4d204" +down_revision = "7f0a28a156f4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_user_display_name_id", table_name="user") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 0b2db515..f2fc3181 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/", response_model=list[protocol.FrontEndUser]) -def get_users( +@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True) +def get_users_ordered_by_username( api_client_id: Optional[UUID] = None, - max_count: Optional[int] = Query(100, gt=0, le=10000), - gt: Optional[str] = None, - lt: Optional[str] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, + search_text: Optional[str] = None, auth_method: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): ur = UserRepository(db, api_client) - users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method) - return [u.to_protocol_frontend_user() for u in users] - - -@router.get("/by_display_name") -def query_frontend_users_by_display_name( - search_text: str, - exact: bool = False, - api_client_id: UUID = None, - max_count: int = Query(20, gt=0, le=1000), - auth_method: str = None, - api_client: ApiClient = Depends(deps.get_api_client), - db: Session = Depends(deps.get_db), -): - ur = UserRepository(db, api_client) - users = ur.query_users_by_display_name( - search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method + users = ur.query_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, ) return [u.to_protocol_frontend_user() for u in users] diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 36cd65c9..0b31495a 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/users/{user_id}", response_model=protocol.FrontEndUser) +@router.get("/by_username", response_model=list[protocol.FrontEndUser]) +def get_users_ordered_by_username( + api_client_id: Optional[UUID] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, + search_text: Optional[str] = None, + auth_method: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = UserRepository(db, api_client) + users = ur.query_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, + ) + return [u.to_protocol_frontend_user() for u in users] + + +@router.get("/by_display_name", response_model=list[protocol.FrontEndUser]) +def get_users_ordered_by_display_name( + api_client_id: Optional[UUID] = None, + gte_display_name: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_display_name: Optional[str] = None, + lt_id: Optional[UUID] = None, + auth_method: Optional[str] = None, + search_text: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = UserRepository(db, api_client) + users = ur.query_users_ordered_by_display_name( + api_client_id=api_client_id, + gte_display_name=gte_display_name, + gt_id=gt_id, + lte_display_name=lte_display_name, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, + ) + return [u.to_protocol_frontend_user() for u in users] + + +@router.get("/{user_id}", response_model=protocol.FrontEndUser) def get_user( user_id: UUID, api_client_id: UUID = None, @@ -31,7 +85,7 @@ def get_user( return user.to_protocol_frontend_user() -@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT) def update_user( user_id: UUID, enabled: Optional[bool] = None, @@ -46,7 +100,7 @@ def update_user( ur.update_user(user_id, enabled, notes) -@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT) def delete_user( user_id: UUID, db: Session = Depends(deps.get_db), diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 0fb36c22..d882a15a 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel class User(SQLModel, table=True): __tablename__ = "user" - __table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),) + __table_args__ = ( + Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True), + Index("ix_user_display_name_id", "display_name", "id", unique=True), + ) id: Optional[UUID] = Field( sa_column=sa.Column( diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 578dc5f1..c0c2a88d 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session +from sqlmodel import Session, and_, or_ from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -135,13 +135,16 @@ class UserRepository: self.db.add(user) return user - def query_users( + def query_users_ordered_by_username( self, api_client_id: Optional[UUID] = None, - limit: Optional[int] = 20, - gt: Optional[str] = None, - lt: Optional[str] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, auth_method: Optional[str] = None, + search_text: Optional[str] = None, + limit: Optional[int] = 100, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -150,34 +153,52 @@ class UserRepository: if api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - users = self.db.query(User) + qry = self.db.query(User).order_by(User.username, User.id) - if api_client_id: - users = users.filter(User.api_client_id == api_client_id) + if gte_username is not None: + if gt_id: + qry = qry.filter( + or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id)) + ) + else: + qry = qry.filter(User.username >= gte_username) + elif gt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_username is not None: + if lt_id: + qry = qry.filter( + or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id)) + ) + else: + qry = qry.filter(User.username <= lte_username) + elif lt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) if auth_method: - users = users.filter(User.auth_method == auth_method) + qry = qry.filter(User.auth_method == auth_method) + if api_client_id: + qry = qry.filter(User.api_client_id == api_client_id) - users = users.order_by(User.display_name) - - if gt: - users = users.filter(User.display_name > gt) - - if lt: - users = users.filter(User.display_name < lt) + if search_text: + pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) + qry = qry.filter(User.username.like(pattern)) if limit is not None: - users = users.limit(limit) + qry = qry.limit(limit) - return users.all() + return qry.all() - def query_users_by_display_name( + def query_users_ordered_by_display_name( self, - search_text: str, - exact: Optional[bool] = False, - limit: Optional[int] = 20, + gte_display_name: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_display_name: Optional[str] = None, + lt_id: Optional[UUID] = None, api_client_id: Optional[UUID] = None, auth_method: Optional[str] = None, + search_text: Optional[str] = None, + limit: Optional[int] = 100, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -186,11 +207,40 @@ class UserRepository: if api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - qry = self.db.query(User).order_by(User.display_name) + qry = self.db.query(User).order_by(User.display_name, User.id) - if exact: - qry = qry.filter(User.display_name == search_text) - else: + if gte_display_name is not None: + if gt_id: + qry = qry.filter( + or_( + User.display_name > gte_display_name, + and_(User.display_name == gte_display_name, User.id > gt_id), + ) + ) + else: + qry = qry.filter(User.display_name >= gte_display_name) + elif gt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_display_name is not None: + if lt_id: + qry = qry.filter( + or_( + User.display_name < lte_display_name, + and_(User.display_name == lte_display_name, User.id < lt_id), + ) + ) + else: + qry = qry.filter(User.display_name <= lte_display_name) + elif lt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if auth_method: + qry = qry.filter(User.auth_method == auth_method) + if api_client_id: + qry = qry.filter(User.api_client_id == api_client_id) + + if search_text: pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) qry = qry.filter(User.display_name.like(pattern))