From ec5bdef719b5b0b7c25e7441866500f22248fc13 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sun, 22 Jan 2023 21:36:52 +0900 Subject: [PATCH 001/101] Not sure this will work --- backend/oasst_backend/api/v1/api.py | 2 ++ backend/requirements.txt | 2 ++ docker-compose.yaml | 2 +- scripts/backend-development/run-local.sh | 1 + website/src/pages/api/auth/[...nextauth].ts | 11 +++++++++++ website/src/pages/dashboard.tsx | 3 +++ 6 files changed, 20 insertions(+), 1 deletion(-) diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index 2931ac05..003f039f 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from oasst_backend.api.v1 import ( admin, + auth, frontend_messages, frontend_users, hugging_face, @@ -23,3 +24,4 @@ api_router.include_router(stats.router, prefix="/stats", tags=["stats"]) api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"]) api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"]) api_router.include_router(admin.router, prefix="/admin", tags=["admin"]) +api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/requirements.txt b/backend/requirements.txt index 0f91315e..dff8d14c 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,7 +6,9 @@ loguru==0.6.0 numpy==1.22.4 psycopg2-binary==2.9.5 pydantic==1.9.1 +pyjwt python-dotenv==0.21.0 +redis scipy==1.8.1 SQLAlchemy==1.4.41 sqlmodel==0.0.8 diff --git a/docker-compose.yaml b/docker-compose.yaml index 908457cd..60048763 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -4,7 +4,7 @@ services: # Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend. 
backend-dev: image: sverrirab/sleep - depends_on: [db, adminer, redis, redis-insights] + depends_on: [db, adminer, redis, redis-insights, webdb, maildev] # Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend. frontend-dev: diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 7366cde6..f81362dd 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -8,6 +8,7 @@ export DEBUG_USE_SEED_DATA=True export DEBUG_SKIP_TOXICITY_CALCULATION=True export DEBUG_ALLOW_SELF_LABELING=True export DEBUG_SKIP_EMBEDDING_COMPUTATION=True +export BACKEND_CORS_ORIGINS='["http://localhost:3000"]' uvicorn main:app --reload --port 8080 --host 0.0.0.0 diff --git a/website/src/pages/api/auth/[...nextauth].ts b/website/src/pages/api/auth/[...nextauth].ts index 3d3dbaa4..af2cbd59 100644 --- a/website/src/pages/api/auth/[...nextauth].ts +++ b/website/src/pages/api/auth/[...nextauth].ts @@ -148,6 +148,17 @@ export const authOptions: AuthOptions = { } }, }, + cookies: { + sessionToken: { + name: `next-auth.session-token`, + options: { + httpOnly: true, + sameSite: "none", + path: "/", + secure: true, + }, + }, + }, session: { strategy: "jwt", }, diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index e0b8bba4..2739821b 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -7,9 +7,12 @@ import { TaskCategory } from "src/components/Tasks/TaskTypes"; import { get } from "src/lib/api"; import type { AvailableTasks, TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; +import useSWR from "swr"; import useSWRImmutable from "swr/immutable"; const Dashboard = () => { + useSWR("http://localhost:8080/api/v1/auth/check", get); + const { data } = useSWRImmutable("/api/available_tasks", get); // TODO: show only these 
tasks: From 3c2d1086b80f354e52c10bd22c84aef0cb574e1b Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Mon, 23 Jan 2023 18:13:33 +0900 Subject: [PATCH 002/101] Adding a new auth route --- backend/oasst_backend/api/v1/auth.py | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 backend/oasst_backend/api/v1/auth.py diff --git a/backend/oasst_backend/api/v1/auth.py b/backend/oasst_backend/api/v1/auth.py new file mode 100644 index 00000000..a2386362 --- /dev/null +++ b/backend/oasst_backend/api/v1/auth.py @@ -0,0 +1,38 @@ +from typing import Union + +import jwt +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import APIKeyCookie +from pydantic import BaseModel + +router = APIRouter() + +oauth2_scheme = APIKeyCookie(name="next-auth.session-token") + +SECRET_KEY = "O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + + +class TokenData(BaseModel): + sub: Union[str, None] = None + + +async def get_current_user(token: str = Depends(oauth2_scheme)): + print("get_current_user") + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + print(payload) + sub: str = payload.get("sub") + if sub is None: + raise credentials_exception + return TokenData(sub=sub) + + +@router.get("/check", response_model=str) +async def auth_check(token_data: TokenData = Depends(get_current_user)): + return token_data.sub From 5b7c32ebec6fd7a80f9d5366c2ef156003ae9244 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 24 Jan 2023 16:48:34 +0900 Subject: [PATCH 003/101] Updating axios to always send credentials --- website/src/lib/api.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/website/src/lib/api.ts b/website/src/lib/api.ts index df4bd399..f7a29f49 100644 --- a/website/src/lib/api.ts 
+++ b/website/src/lib/api.ts @@ -6,8 +6,11 @@ const headers = { "Content-Type": "application/json", }; +// Create Axios such that we always send credential cookies along with the +// request. This allows the Backend services to authenticate the user. const api = axios.create({ headers, + withCredentials: true, }); export const get = (url: string) => api.get(url).then((res) => res.data); From 3215a7bbf8d5efa06ec126a537c050794fde9bab Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Wed, 25 Jan 2023 05:03:43 +0000 Subject: [PATCH 004/101] [feature] add initial version of anthropic dataset --- model/reward/instructor/rank_datasets.py | 36 +++++++++++++++++++ model/reward/instructor/tests/test_dataset.py | 12 ++++++- model/reward/instructor/trainer.py | 2 +- model/reward/instructor/utils.py | 9 +++-- 4 files changed, 55 insertions(+), 4 deletions(-) diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index bd53eb02..5100d455 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -237,3 +237,39 @@ class HFDataset(Dataset): class GPTJSynthetic(HFDataset): def __init__(self) -> None: super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train") + + +class AnthropicRLHF(Dataset): + """ + The data are described in the paper: + Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback. + If you find the data useful, please cite the paper. + The data format is very simple -- each line of the jsonl files contains a pair of texts, + one "chosen" and one "rejected". + + """ + + def __init__(self, split="train", sep_token="") -> None: + super().__init__() + assert split in ("train", "test") + if sep_token is None: + sep_token = " . 
" + self.pairs = [] + # using prompt as our index will allows us + # to add additional generated prompt later + major_split = split if "train" == split else "test" + dataset = load_dataset("Anthropic/hh-rlhf")[major_split] + for data in dataset: + prompt, pos_postfix = data["chosen"].split("Assistant:", maxsplit=1) + pos_postfix = pos_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token) + _, neg_postfix = data["rejected"].split("Assistant:", maxsplit=1) + neg_postfix = neg_postfix.replace("\n\nHuman: ", sep_token).replace("\n\nAssistant: ", sep_token) + self.pairs.append((prompt, (pos_postfix.strip(), neg_postfix.strip()))) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + context, pair = self.pairs[index] + + return context, [pair] diff --git a/model/reward/instructor/tests/test_dataset.py b/model/reward/instructor/tests/test_dataset.py index 832aace3..5cc9b7e8 100644 --- a/model/reward/instructor/tests/test_dataset.py +++ b/model/reward/instructor/tests/test_dataset.py @@ -1,5 +1,5 @@ from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality -from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT +from rank_datasets import AnthropicRLHF, DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -25,6 +25,16 @@ def test_webgpt(): print(batch["input_ids"].shape) +def test_anthropic_rlhf(): + + tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large") + collate_fn = DataCollatorForPairRank(tokenizer, max_length=200) + dataset = AnthropicRLHF("test", sep_token=tokenizer.sep_token) + dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32) + for batch in dataloader: + print(batch["input_ids"].shape) + + def test_hf_summary_quality(): tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large") diff --git a/model/reward/instructor/trainer.py 
b/model/reward/instructor/trainer.py index 940c0708..abb05a42 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -155,7 +155,7 @@ if __name__ == "__main__": ) tokenizer = get_tokenizer(training_conf["tokenizer_name"]) - train, evals = get_datasets(training_conf["datasets"]) + train, evals = get_datasets(training_conf["datasets"], tokenizer) if "rankgen" in model_name: collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"]) else: diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index a6f3da4e..e8c7160b 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -99,8 +99,8 @@ def argument_parsing(parser): return params -def get_datasets(dataset_list: List[AnyStr]): - from rank_datasets import GPTJSynthetic, HFSummary, WebGPT +def get_datasets(dataset_list: List[AnyStr], tokenizer): + from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT from torch.utils.data import ConcatDataset train_datasets, evals = [], {} @@ -121,6 +121,11 @@ def get_datasets(dataset_list: List[AnyStr]): train, eval = train_val_dataset(dataset, 0.1) train_datasets.append(train) evals["gptsynthetic"] = eval + elif "anthropic_rlhf" == dataset_name: + train = AnthropicRLHF("train", tokenizer.sep_token) + eval = AnthropicRLHF("test", tokenizer.sep_token) + train_datasets.append(train) + evals["anthropic_rlhf"] = eval train = ConcatDataset(train_datasets) return train, evals From 94d2ed820e7536d8fa7070d11729cce697fbb0b3 Mon Sep 17 00:00:00 2001 From: notmd Date: Wed, 25 Jan 2023 15:04:39 +0700 Subject: [PATCH 005/101] Refactor `OasstApiClient` --- website/src/components/UserTable.tsx | 3 +- website/src/lib/oasst_api_client.ts | 215 +++++++++------------------ website/src/types/Users.ts | 16 ++ 3 files changed, 87 insertions(+), 147 deletions(-) diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index 71d4fe44..57a96c95 
100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -4,8 +4,7 @@ import { Pencil } from "lucide-react"; import Link from "next/link"; import { memo, useState } from "react"; import { get } from "src/lib/api"; -import { FetchUsersResponse } from "src/lib/oasst_api_client"; -import type { User } from "src/types/Users"; +import type { FetchUsersResponse, User } from "src/types/Users"; import useSWR from "swr"; import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable"; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 7c48a288..0b3ecf1d 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,36 +1,20 @@ import type { Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; import type { AvailableTasks } from "src/types/Task"; -import type { BackendUser, BackendUserCore, User } from "src/types/Users"; +import type { BackendUser, BackendUserCore, FetchUsersParams, FetchUsersResponse } from "src/types/Users"; export class OasstError { message: string; errorCode: number; httpStatusCode: number; - constructor(message: string, errorCode: number, httpStatusCode?: number) { + constructor(message: string, errorCode: number, httpStatusCode: number) { this.message = message; this.errorCode = errorCode; this.httpStatusCode = httpStatusCode; } } -export type FetchUsersParams = { - limit: number; - cursor?: string; - direction: "forward" | "back"; - searchDisplayName?: string; - sortKey?: "username" | "display_name"; -}; - -export type FetchUsersResponse = { - items: T[]; - next?: string; - prev?: string; - sort_key: "username" | "display_name"; - order: "asc" | "desc"; -}; - export class OasstApiClient { oasstApiUrl: string; oasstApiKey: string; @@ -39,88 +23,6 @@ export class OasstApiClient { this.oasstApiUrl = oasstApiUrl; this.oasstApiKey = oasstApiKey; } - - private async 
post(path: string, body: any): Promise { - const resp = await fetch(`${this.oasstApiUrl}${path}`, { - method: "POST", - headers: { - "X-API-Key": this.oasstApiKey, - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - }); - - if (resp.status === 204) { - return null; - } - - if (resp.status >= 300) { - const errorText = await resp.text(); - let error: any; - try { - error = JSON.parse(errorText); - } catch (e) { - throw new OasstError(errorText, 0, resp.status); - } - throw new OasstError(error.message ?? error, error.error_code, resp.status); - } - - return await resp.json(); - } - - private async put(path: string): Promise { - const resp = await fetch(`${this.oasstApiUrl}${path}`, { - method: "PUT", - headers: { - "X-API-Key": this.oasstApiKey, - }, - }); - - if (resp.status === 204) { - return null; - } - - if (resp.status >= 300) { - const errorText = await resp.text(); - let error: any; - try { - error = JSON.parse(errorText); - } catch (e) { - throw new OasstError(errorText, 0, resp.status); - } - throw new OasstError(error.message ?? error, error.error_code, resp.status); - } - - return await resp.json(); - } - - private async get(path: string): Promise { - const resp = await fetch(`${this.oasstApiUrl}${path}`, { - method: "GET", - headers: { - "X-API-Key": this.oasstApiKey, - "Content-Type": "application/json", - }, - }); - - if (resp.status === 204) { - return null; - } - - if (resp.status >= 300) { - const errorText = await resp.text(); - let error: any; - try { - error = JSON.parse(errorText); - } catch (e) { - throw new OasstError(errorText, 0, resp.status); - } - throw new OasstError(error.message ?? error, error.error_code, resp.status); - } - - return await resp.json(); - } - // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. // This is a raw Json type, so we can't use it to strongly type the task. 
@@ -133,13 +35,13 @@ export class OasstApiClient { } async ackTask(taskId: string, messageId: string): Promise { - return this.post(`/api/v1/tasks/${taskId}/ack`, { + await this.post(`/api/v1/tasks/${taskId}/ack`, { message_id: messageId, }); } async nackTask(taskId: string, reason: string): Promise { - return this.post(`/api/v1/tasks/${taskId}/nack`, { + await this.post(`/api/v1/tasks/${taskId}/nack`, { reason, }); } @@ -170,8 +72,8 @@ export class OasstApiClient { /** * Returns the tasks availability information for given `user`. */ - async fetch_tasks_availability(user: object): Promise { - return this.post("/api/v1/tasks/availability", user); + async fetch_tasks_availability(user: object): Promise { + return this.post("/api/v1/tasks/availability", user); } /** @@ -191,18 +93,12 @@ export class OasstApiClient { /** * Returns the `BackendUser` associated with `user_id` */ - async fetch_user(user_id: string): Promise { + async fetch_user(user_id: string): Promise { return this.get(`/api/v1/users/${user_id}`); } /** * Returns the set of `BackendUser`s stored by the backend. - * - * @param {number} max_count - The maximum number of users to fetch. - * @param {string} cursor - The user's `display_name` to use when paginating. - * @param {boolean} isForward - If true and `cursor` is not empty, pages - * forward. If false and `cursor` is not empty, pages backwards. - * @returns {Promise} A Promise that returns an array of `BackendUser` objects. */ async fetch_users({ direction, @@ -210,46 +106,28 @@ export class OasstApiClient { cursor, searchDisplayName, sortKey = "display_name", - }: FetchUsersParams): Promise { - const params = new URLSearchParams({ + }: FetchUsersParams): Promise { + return this.get(`/api/v1/users/cursor`, { search_text: searchDisplayName, sort_key: sortKey, - max_count: limit.toString(), + max_count: limit, + after: direction === "forward" ? cursor : undefined, + before: direction === "back" ? 
cursor : undefined, }); - - // The backend API uses different query parameters depending on the - // pagination direction but they both take the same cursor value. - // Depending on direction, pick the right query param. - if (cursor !== "") { - params.append(direction === "forward" ? "after" : "before", cursor); - } - const BASE_URL = `/api/v1/users/cursor`; - const url = `${BASE_URL}/?${params.toString()}`; - return this.get(url); } - // async fetch_user_by_display_name(name: string): Promise { - // const params = new URLSearchParams({ - // search_text: name, - // }); - - // const endpoint = `/api/v1/frontend_users/by_display_name`; - - // return this.get(`${endpoint}?${params.toString()}`); - // } - /** * Returns the `Message`s associated with `user_id` in the backend. */ - async fetch_user_messages(user_id: string): Promise { - return this.get(`/api/v1/users/${user_id}/messages`); + async fetch_user_messages(user_id: string): Promise { + return this.get(`/api/v1/users/${user_id}/messages`); } /** * Updates the backend's knowledge about the `user_id`. 
*/ - async set_user_status(user_id: string, is_enabled: boolean, notes): Promise { - return this.put(`/api/v1/users/users/${user_id}?enabled=${is_enabled}¬es=${notes}`); + async set_user_status(user_id: string, is_enabled: boolean, notes: string): Promise { + await this.put(`/api/v1/users/users/${user_id}?enabled=${is_enabled}¬es=${notes}`); } /** @@ -265,18 +143,65 @@ export class OasstApiClient { async fetch_leaderboard( time_frame: LeaderboardTimeFrame, { limit = 20 }: { limit?: number } - ): Promise { - const params = new URLSearchParams({ - limit: limit.toString(), - }); - return this.get(`/api/v1/leaderboards/${time_frame}?${params.toString()}`); + ): Promise { + return this.get(`/api/v1/leaderboards/${time_frame}`, { limit }); } /** * Returns the counts of all tasks (some might be zero) */ - async fetch_available_tasks(user: BackendUserCore, lang: string): Promise { - return this.post(`/api/v1/tasks/availability?lang=${lang}`, user); + async fetch_available_tasks(user: BackendUserCore, lang: string): Promise { + return this.post(`/api/v1/tasks/availability?lang=${lang}`, user); + } + + private async post(path: string, body: unknown) { + return this.request("POST", path, { + body: JSON.stringify(body), + }); + } + + private async put(path: string) { + return this.request("PUT", path); + } + + private async get(path: string, query: Record = {}) { + const filteredQuery = Object.fromEntries( + Object.entries(query).filter(([, value]) => value !== undefined) + ) as Record; + + const params = new URLSearchParams(filteredQuery).toString(); + + return this.request("GET", `${path}${query ? 
`?${params}` : ""}`); + } + + private async request(method: "GET" | "POST" | "PUT", path: string, init?: RequestInit): Promise { + const resp = await fetch(`${this.oasstApiUrl}${path}`, { + method, + ...init, + headers: { + "X-API-Key": this.oasstApiKey, + "Content-Type": "application/json", + ...init?.headers, + }, + }); + + if (resp.status === 204) { + return null; + } + + if (resp.status >= 300) { + const errorText = await resp.text(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let error: any; + try { + error = JSON.parse(errorText); + } catch (e) { + throw new OasstError(errorText, 0, resp.status); + } + throw new OasstError(error.message ?? error, error.error_code, resp.status); + } + + return await resp.json(); } } diff --git a/website/src/types/Users.ts b/website/src/types/Users.ts index 39d2a663..52240b5f 100644 --- a/website/src/types/Users.ts +++ b/website/src/types/Users.ts @@ -51,3 +51,19 @@ export interface User extends BackendUser { */ role: string; } + +export type FetchUsersParams = { + limit: number; + cursor?: string; + direction: "forward" | "back"; + searchDisplayName?: string; + sortKey?: "username" | "display_name"; +}; + +export type FetchUsersResponse = { + items: T[]; + next?: string; + prev?: string; + sort_key: "username" | "display_name"; + order: "asc" | "desc"; +}; From 31630be319afc49a7db07ec4d176eefd5bfd4696 Mon Sep 17 00:00:00 2001 From: notmd Date: Wed, 25 Jan 2023 15:48:53 +0700 Subject: [PATCH 006/101] fix test & build --- website/{src => }/.prettierignore | 0 website/{src => }/.prettierrc.json | 0 .../contract/oasst_api_contract_tests.cy.ts | 17 +++++------------ website/src/lib/api.ts | 2 +- website/src/pages/api/admin/users.ts | 3 ++- 5 files changed, 8 insertions(+), 14 deletions(-) rename website/{src => }/.prettierignore (100%) rename website/{src => }/.prettierrc.json (100%) diff --git a/website/src/.prettierignore b/website/.prettierignore similarity index 100% rename from 
website/src/.prettierignore rename to website/.prettierignore diff --git a/website/src/.prettierrc.json b/website/.prettierrc.json similarity index 100% rename from website/src/.prettierrc.json rename to website/.prettierrc.json diff --git a/website/cypress/contract/oasst_api_contract_tests.cy.ts b/website/cypress/contract/oasst_api_contract_tests.cy.ts index d2ffeba3..0f217f85 100644 --- a/website/cypress/contract/oasst_api_contract_tests.cy.ts +++ b/website/cypress/contract/oasst_api_contract_tests.cy.ts @@ -12,25 +12,18 @@ describe("Contract test for Oasst API", function () { } as BackendUserCore; it("can fetch a task", async () => { - expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null; + expect(await oasstApiClient.fetchTask("random", testUser, "en")).to.be.not.null; }); it("can ack a task", async () => { - const task = await oasstApiClient.fetchTask("random", testUser); - expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null; + const task = await oasstApiClient.fetchTask("random", testUser, "en"); + expect(await oasstApiClient.ackTask(task.id, "321")).to.be.undefined; }); it("can record a taskInteraction", async () => { - const task = await oasstApiClient.fetchTask("random", testUser); + const task = await oasstApiClient.fetchTask("random", testUser, "en"); expect( - await oasstApiClient.interactTask( - "text_reply_to_message", - task.id, - "321", - "1", - { text: "Test" }, - testUser - ) + await oasstApiClient.interactTask("text_reply_to_message", task.id, "321", "1", { text: "Test" }, testUser, "en") ).to.be.not.null; }); diff --git a/website/src/lib/api.ts b/website/src/lib/api.ts index df4bd399..d61016d2 100644 --- a/website/src/lib/api.ts +++ b/website/src/lib/api.ts @@ -17,7 +17,7 @@ export const post = (url: string, { arg: data }) => api.post(url, data).then((re api.interceptors.response.use( (response) => response, (error) => { - throw new OasstError(error.message ?? 
error, error.error_code); + throw new OasstError(error.message ?? error, error.error_code, error?.response?.status || -1); } ); diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts index 57944cff..d10c91b0 100644 --- a/website/src/pages/api/admin/users.ts +++ b/website/src/pages/api/admin/users.ts @@ -1,6 +1,7 @@ import { withRole } from "src/lib/auth"; -import { FetchUsersParams, oasstApiClient } from "src/lib/oasst_api_client"; +import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; +import { FetchUsersParams } from "src/types/Users"; /** * The number of users to fetch in a single request. Could later be a query parameter. From 97cd57a300cf8abfa149764eaa9f97e64e7dda2d Mon Sep 17 00:00:00 2001 From: notmd Date: Wed, 25 Jan 2023 16:02:20 +0700 Subject: [PATCH 007/101] fix test --- website/.prettierignore | 3 + website/README.md | 167 +++++++----------- website/cypress/README.md | 98 +++++----- website/cypress/components/Container.cy.tsx | 5 +- .../contract/oasst_api_contract_tests.cy.ts | 2 +- website/cypress/e2e/tasks/random.cy.ts | 4 +- website/cypress/support/commands.ts | 15 +- website/src/lib/oasst_api_client.ts | 8 +- website/styles/Home.module.css | 52 +++--- website/styles/Theme/index.tsx | 6 +- 10 files changed, 147 insertions(+), 213 deletions(-) diff --git a/website/.prettierignore b/website/.prettierignore index e69de29b..0be2a485 100644 --- a/website/.prettierignore +++ b/website/.prettierignore @@ -0,0 +1,3 @@ +.eslintrc.json +tailwind.config.js +.storybook/* diff --git a/website/README.md b/website/README.md index a30f2754..09b40e49 100644 --- a/website/README.md +++ b/website/README.md @@ -2,8 +2,7 @@ ## Purpose -This provides a comprehensive webapp interface for LAION's Open Assistant -project. Initially it will support: +This provides a comprehensive webapp interface for LAION's Open Assistant project. Initially it will support: 1. 
User registration using either Discord or Email. 1. Adding responses to incomplete Open Assistant tasks. @@ -11,8 +10,7 @@ project. Initially it will support: 1. Viewing an activity leaderboard. 1. Tracking community wide updates. -This interface compliments the Discord bot and will give access to the same -underlying tasks. +This interface compliments the Discord bot and will give access to the same underlying tasks. ## Contributing @@ -22,67 +20,54 @@ This website is built using: 1. [npm](https://www.npmjs.com/): The node package manager for building. 1. [React](https://reactjs.org/): The core frontend framework. -1. [Next.js](https://nextjs.org/): A React scaffolding framework to streamline - development. -1. [Prisma](https://www.prisma.io/): An ORM to interact with a web specific - [Postgres](https://www.postgresql.org/) database. -1. [NextAuth.js](https://next-auth.js.org/): A user authentication framework to - ensure we handle accounts with best practices. -1. [TailwindCSS](https://tailwindcss.com/): A general purpose framework for - styling any component. -1. [Chakra-UI](https://chakra-ui.com/): A wide collection of pre-built UI - components that generally look pretty good. +1. [Next.js](https://nextjs.org/): A React scaffolding framework to streamline development. +1. [Prisma](https://www.prisma.io/): An ORM to interact with a web specific [Postgres](https://www.postgresql.org/) + database. +1. [NextAuth.js](https://next-auth.js.org/): A user authentication framework to ensure we handle accounts with best + practices. +1. [TailwindCSS](https://tailwindcss.com/): A general purpose framework for styling any component. +1. [Chakra-UI](https://chakra-ui.com/): A wide collection of pre-built UI components that generally look pretty good. ### Set up your environment -To contribute to the website, make sure you have the following setup and -installed: +To contribute to the website, make sure you have the following setup and installed: -1. 
[NVM](https://github.com/nvm-sh/nvm): The Node Version Manager makes it easy - to ensure you have the right NodeJS version installed. Once installed, run - `nvm use 16` to use Node 16.x. The website is known to be stable with NodeJS +1. [NVM](https://github.com/nvm-sh/nvm): The Node Version Manager makes it easy to ensure you have the right NodeJS + version installed. Once installed, run `nvm use 16` to use Node 16.x. The website is known to be stable with NodeJS version 16.x. This will install both Node and NPM. -1. [Docker](https://www.docker.com/): We use docker to simplify running - dependent services. +1. [Docker](https://www.docker.com/): We use docker to simplify running dependent services. ### Getting everything up and running If you're doing active development we suggest the following workflow: 1. In one tab, navigate to the project root. -1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can - optionally include `-d` to detach and later track the logs if desired. +1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can optionally include `-d` to detach and + later track the logs if desired. 1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`. 1. Run `npm ci` -1. Run `npx prisma db push` (This is also needed when you restart the docker - stack from scratch). -1. Run `npm run dev`. Now the website is up and running locally at - `http://localhost:3000`. -1. To create an account, login via the user using email authentication and - navigate to `http://localhost:1080`. Check the email listed and click the - log in link. You're now logged in and authenticated. +1. Run `npx prisma db push` (This is also needed when you restart the docker stack from scratch). +1. Run `npm run dev`. Now the website is up and running locally at `http://localhost:3000`. +1. To create an account, login via the user using email authentication and navigate to `http://localhost:1080`. 
Check + the email listed and click the log in link. You're now logged in and authenticated. ### Using debug user credentials -You can use the debug credentials provider to log in without fancy emails or -OAuth. +You can use the debug credentials provider to log in without fancy emails or OAuth. -1. This feature is automatically on in development mode, i.e. when you run - `npm run dev`. In case you want to do the same with a production build (for - example, the docker image), then run the website with environment variable +1. This feature is automatically on in development mode, i.e. when you run `npm run dev`. In case you want to do the + same with a production build (for example, the docker image), then run the website with environment variable `DEBUG_LOGIN=true`. 1. Use the `Login` button in the top right to go to the login page. -1. You should see a section for debug credentials. Enter any username you wish, - you will be logged in as that user. +1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user. ### Using Storybook -To develop components using [Storybook](https://storybook.js.org/) run -`npm run storybook`. Then navigate to in your browser to -`http://localhost:6006`. +To develop components using [Storybook](https://storybook.js.org/) run `npm run storybook`. Then navigate to in your +browser to `http://localhost:6006`. -To create a new story create a file named `[componentName].stories.js`. An -example how such a story could look like, see `Header.stories.jsx`. +To create a new story create a file named `[componentName].stories.js`. An example how such a story could look like, see +`Header.stories.jsx`. ## Code Layout @@ -90,12 +75,10 @@ example how such a story could look like, see `Header.stories.jsx`. All react code is under `src/` with a few sub directories: -1. `pages/`: All pages a user could navigate too and API URLs which are under - `pages/api/`. -1. 
`components/`: All re-usable React components. If something gets used twice - we should create a component and put it here. -1. `lib/`: A generic place to store library files that are used anywhere. This - doesn't have much structure yet. +1. `pages/`: All pages a user could navigate too and API URLs which are under `pages/api/`. +1. `components/`: All re-usable React components. If something gets used twice we should create a component and put it + here. +1. `lib/`: A generic place to store library files that are used anywhere. This doesn't have much structure yet. NOTE: `styles/` can be ignored for now. @@ -113,25 +96,20 @@ We're not really using CSS styles. `styles/` can be ignored. ## Testing the UI -Cypress is used for end-to-end (e2e) and component testing and is configured in -`./cypress.config.ts`. The `./cypress` folder is used for supporting -configuration files etc. +Cypress is used for end-to-end (e2e) and component testing and is configured in `./cypress.config.ts`. The `./cypress` +folder is used for supporting configuration files etc. - Store e2e tests in the `./cypress/e2e` folder. -- Store component tests adjacent to the component being tested. If you want to - wriite a test for `./src/components/Layout.tsx` then store the test file at - `./src/components/Layout.cy.tsx`. +- Store component tests adjacent to the component being tested. If you want to wriite a test for + `./src/components/Layout.tsx` then store the test file at `./src/components/Layout.cy.tsx`. A few npm scripts are available for convenience: -- `npm run cypress`: Useful for development, it opens Cypress and allows you to - explore, run and debug tests. It assumes you have the NextJS site running at - `localhost:3000`. -- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before - sending a PR or to run in CI pipelines. 
-- `npm run cypress:image-baseline`: If you have tests failing because of visual - changes that was expected, this command will update the baseline images stored - in `./cypress-visual-screenshots/baseline` with those from the adjacent +- `npm run cypress`: Useful for development, it opens Cypress and allows you to explore, run and debug tests. It assumes + you have the NextJS site running at `localhost:3000`. +- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before sending a PR or to run in CI pipelines. +- `npm run cypress:image-baseline`: If you have tests failing because of visual changes that were expected, this command + will update the baseline images stored in `./cypress-visual-screenshots/baseline` with those from the adjacent comparison folder. More can be found in the [docs of `uktrade/cypress-image-diff`](https://github.com/uktrade/cypress-image-diff/blob/main/docs/CLI.md#update-all-baseline-images-for-failing-tests). @@ -141,10 +119,9 @@ Read more in the [./cypress README](cypress/). Jest and React Testing Library are used for unit testing JS/TS/TSX code. -- Store unit test files adjacent to the file being tested and have the filename - end with `.test.ts` for non-React code or `.test.tsx` for React code. -- `npm run jest`: automatically runs tests and watches for any relevant changes - to rerun tests. +- Store unit test files adjacent to the file being tested and have the filename end with `.test.ts` for non-React code + or `.test.tsx` for React code. +- `npm run jest`: automatically runs tests and watches for any relevant changes to rerun tests. Read more in the [./src/README.md](src/README.md). @@ -152,30 +129,25 @@ Read more in the [./src/README.md](src/README.md). When writing code for the website, we have a few best practices: -1. When importing packages import external dependencies first then local - dependencies. Order them alphabetically according to the package name. -1.
When trying to implement something new, check if - [Chakra-UI](https://chakra-ui.com/) has components that are close enough to - your need. For example Sliders, Radio Buttons, Progress indicators, etc. - They have a lot and we can save time by re-using what they have and tweaking - the style as needed. -1. Format everything with [Prettier](https://prettier.io/). This is done by - default with pre-submits. We currently don't have any custom settings. -1. Define functional React components (with types for all properties when - feasible). +1. When importing packages import external dependencies first then local dependencies. Order them alphabetically + according to the package name. +1. When trying to implement something new, check if [Chakra-UI](https://chakra-ui.com/) has components that are close + enough to your need. For example Sliders, Radio Buttons, Progress indicators, etc. They have a lot and we can save + time by re-using what they have and tweaking the style as needed. +1. Format everything with [Prettier](https://prettier.io/). This is done by default with pre-submits. We currently + don't have any custom settings. +1. Define functional React components (with types for all properties when feasible). ### Developing New Features -When working on new features or making significant changes that can't be done -within a single Pull Request, we ask that you make use of Feature Flags. +When working on new features or making significant changes that can't be done within a single Pull Request, we ask that +you make use of Feature Flags. -We've set up -[`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to -make this easier. To get started: +We've set up [`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to make this easier. To get +started: -1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you - can copy as an example. 
Be sure to `isActive` to true when testing your - features but false when submitting your PR. +1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you can copy as an example. Be sure to + set `isActive` to true when testing your features but false when submitting your PR. 1. Use your flag wherever you add a new UI element. This can be done with: ```js @@ -188,29 +160,24 @@ import { Flags } from "react-feature-flags"; You can see an example of how this works by checking `website/src/components/Header/Headers.tsx` where we use `flagTest`. -1. Once you've finished building out the feature and it is ready for everyone - to use, it's safe to remove the `Flag` wrappers around your component and - the entry in `flags.ts`. +1. Once you've finished building out the feature and it is ready for everyone to use, it's safe to remove the `Flag` + wrappers around your component and the entry in `flags.ts`. ### URL Paths -To use stable and consistent URL paths, we recommend the following strategy for -new tasks: +To use stable and consistent URL paths, we recommend the following strategy for new tasks: -1. For any task that involves writing a free-form response, put the page under - `website/src/pages/create` with a page name matching the task type, such as - `initial_prompt.tsx`. -1. For any task that evaluates, rates, or ranks content, put the page under - `website/src/pages/evaluate` with a page name matching the task type such as - `rank_initial_prompts.tsx`. +1. For any task that involves writing a free-form response, put the page under `website/src/pages/create` with a page + name matching the task type, such as `initial_prompt.tsx`. +1. For any task that evaluates, rates, or ranks content, put the page under `website/src/pages/evaluate` with a page + name matching the task type such as `rank_initial_prompts.tsx`. -With this we'll be able to ensure these contribution pages are hidden from -logged out users but accessible to logged in users.
+With this we'll be able to ensure these contribution pages are hidden from logged out users but accessible to logged in +users. ## Learn More To learn more about Next.js, take a look at the following resources: -- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js - features and API. +- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. diff --git a/website/cypress/README.md b/website/cypress/README.md index 4750cbf6..d6a2b383 100644 --- a/website/cypress/README.md +++ b/website/cypress/README.md @@ -1,24 +1,19 @@ # Component and e2e testing with Cypress -[Cypress](https://www.cypress.io/) is used for both component- and end-to-end -testing. Below there's a few examples for the context of this site. To learn -more, the -[Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app) -has it all. +[Cypress](https://www.cypress.io/) is used for both component- and end-to-end testing. Below there are a few examples for +the context of this site. To learn more, the +[Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app) has it all. -Don't get scared by the commercial offerings they offer. Their core is open -source, the cloud offering is not necesarry at all and can be replaced by CI -tooling and [community efforts](https://sorry-cypress.dev/). +Don't get scared by the commercial offerings they offer. Their core is open source, the cloud offering is not necessary +at all and can be replaced by CI tooling and [community efforts](https://sorry-cypress.dev/). # Component testing -To write a new component test, you either create a new `.tsx` adjacent to the -component you want to test or you can use the guide presented yo you when -running `npm run cypress` which allows you to easily create the skeleton test -for an existing component.
+To write a new component test, you either create a new `.tsx` adjacent to the component you want to test or you can use +the guide presented to you when running `npm run cypress` which allows you to easily create the skeleton test for an +existing component. -If you have a `Button.tsx` component, create a file next to it called -`Button.cy.tsx` which could look like this: +If you have a `Button.tsx` component, create a file next to it called `Button.cy.tsx` which could look like this: ```typescript import React from "react"; @@ -35,28 +30,24 @@ describe(" + + + ); +}; From 1eb3f05c44d1aea680b9cc936a459112159f49ae Mon Sep 17 00:00:00 2001 From: kayjay Date: Thu, 26 Jan 2023 01:49:03 -0800 Subject: [PATCH 017/101] PR: Create notebook to convert r/changemyview data (#839) * (#737) Create notebook to convert r/changemyview data into cleaner format --- .../changemyview-builder/README.md | 37 ++ .../changemyview-builder/data_processor.ipynb | 577 ++++++++++++++++++ 2 files changed, 614 insertions(+) create mode 100644 notebooks/data-augmentation/changemyview-builder/README.md create mode 100644 notebooks/data-augmentation/changemyview-builder/data_processor.ipynb diff --git a/notebooks/data-augmentation/changemyview-builder/README.md b/notebooks/data-augmentation/changemyview-builder/README.md new file mode 100644 index 00000000..62ab9b2e --- /dev/null +++ b/notebooks/data-augmentation/changemyview-builder/README.md @@ -0,0 +1,37 @@ +# README + +## Introduction + +This program converts data obtained from the subreddit r/changemyview into a cleaner format for further data processing. The data is not clean enough to be used directly in a model yet, and additional preprocessing is required.
+ +## Data Format + +The cleaned data is stored in an Apache Parquet file with the following columns: + +| Column Name | Description | Data Type | +|-------------|------------------------------------------------------------------------|----------------| +| INSTRUCTION | Post title + body text | String | +| RESPONSE | Body text of comments attempting to change OP's mind of `INSTRUCTION`. | List\<String\> | +| SOURCE | Permalink to the reddit post | String | +| METADATA | Metadata related to `RESPONSE`. | Dict\ | + +### Metadata +Currently, metadata is only broken into one category: +- `detoxify_labels` - A dictionary of values output by the [Unitaryai Detoxifier](https://github.com/unitaryai/detoxify) model, fitted to every comment under any given post. + +## Usage + +To use the program, follow these instructions: + +1. **Clone the repository** - `git clone https://github.com/LAION-AI/Open-Assistant.git` +2. **Navigate to the project directory** - `cd notebooks/data-augmentation/changemyview-builder` +3. **Open the Jupyter Notebook** - `jupyter notebook data_processor.ipynb` +4. **Run the program** - Go through the notebook and run the cells + +## Contributing + +If you would like to contribute to this project, please fork the repository and submit a pull request with your changes. + +## License + +This project is licensed under the Apache-2.0 License - see the [LICENSE](LICENSE) file for details.
\ No newline at end of file diff --git a/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb b/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb new file mode 100644 index 00000000..03155d40 --- /dev/null +++ b/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb @@ -0,0 +1,577 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# r/ChangeMyView data converter\n", + "Converts subreddit data into readable format for ML training\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/changemyview-builder/data_processor.ipynb)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 65, + "outputs": [], + "source": [ + "### REMEMBER: setup the .env before running this code!\n", + "\n", + "\"\"\"CONSTANTS\"\"\"\n", + "\n", + "# Set the head number to the amount of entries you want to load in minus one\n", + "ENTRIES_COUNT = 10\n", + "\n", + "# Set the threshold for toxic comments to be removed\n", + "TOXIC_THRESHOLD = 0.95" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 66, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pandas in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (1.4.4)\r\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2.8.2)\r\n", + "Requirement already satisfied: pytz>=2020.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (2022.1)\r\n", + "Requirement already satisfied: numpy>=1.18.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pandas) (1.21.5)\r\n", + "Requirement already satisfied: six>=1.5 in 
/Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\r\n", + "Requirement already satisfied: praw in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (7.6.1)\r\n", + "Requirement already satisfied: websocket-client>=0.54.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.58.0)\r\n", + "Requirement already satisfied: update-checker>=0.18 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (0.18.0)\r\n", + "Requirement already satisfied: prawcore<3,>=2.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from praw) (2.3.0)\r\n", + "Requirement already satisfied: requests<3.0,>=2.6.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from prawcore<3,>=2.1->praw) (2.28.1)\r\n", + "Requirement already satisfied: six in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from websocket-client>=0.54.0->praw) (1.16.0)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (3.3)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2022.9.24)\r\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (2.0.4)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests<3.0,>=2.6.0->prawcore<3,>=2.1->praw) (1.26.11)\r\n", + "Requirement already satisfied: python-dotenv in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.21.0)\r\n", + "Requirement already satisfied: pyarrow in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (10.0.1)\r\n", + "Requirement already satisfied: numpy>=1.16.6 in 
/Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from pyarrow) (1.21.5)\r\n", + "Requirement already satisfied: detoxify in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (0.5.1)\r\n", + "Requirement already satisfied: transformers==4.22.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (4.22.1)\r\n", + "Requirement already satisfied: torch>=1.7.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (1.13.1)\r\n", + "Requirement already satisfied: sentencepiece>=0.1.94 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from detoxify) (0.1.97)\r\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.9.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.11.1)\r\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2022.7.9)\r\n", + "Requirement already satisfied: pyyaml>=5.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (6.0)\r\n", + "Requirement already satisfied: tqdm>=4.27 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (4.64.1)\r\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (0.12.1)\r\n", + "Requirement already satisfied: filelock in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (3.6.0)\r\n", + "Requirement already satisfied: packaging>=20.0 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (21.3)\r\n", + "Requirement already satisfied: numpy>=1.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (1.21.5)\r\n", + "Requirement already 
satisfied: requests in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from transformers==4.22.1->detoxify) (2.28.1)\r\n", + "Requirement already satisfied: typing-extensions in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from torch>=1.7.0->detoxify) (4.3.0)\r\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from packaging>=20.0->transformers==4.22.1->detoxify) (3.0.9)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (3.3)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (1.26.11)\r\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2.0.4)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers==4.22.1->detoxify) (2022.9.24)\r\n", + "Requirement already satisfied: tqdm in /Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages (4.64.1)\r\n" + ] + } + ], + "source": [ + "# Install any dependencies\n", + "!pip install pandas\n", + "!pip install praw\n", + "!pip install python-dotenv\n", + "!pip install pyarrow\n", + "!pip install detoxify\n", + "!pip install tqdm" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import praw\n", + "import os\n", + "from os.path import join, dirname\n", + "from dotenv import main\n", + "\n", + "# Make sure you create a .env file and fill in all the necessary information in the same folder as this script!\n", + 
"main.load_dotenv(join(dirname(os.path.realpath('__file__')), '.env'))\n", + "\n", + "reddit = praw.Reddit(\n", + " client_id=os.environ.get(\"CLIENT_ID\"),\n", + " client_secret=os.environ.get(\"CLIENT_SECRET\"),\n", + " user_agent=\"CMV_Scraper\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "outputs": [], + "source": [ + "# load the data\n", + "import tarfile\n", + "import os.path\n", + "import json\n", + "import re\n", + "from bz2 import BZ2File\n", + "from urllib import request\n", + "from io import BytesIO\n", + "\n", + "import numpy as np\n", + "\n", + "\n", + "fname = \"cmv.tar.bz2\"\n", + "url = \"https://chenhaot.com/data/cmv/\" + fname\n", + "\n", + "# download if not exists\n", + "if not os.path.isfile(fname):\n", + " f = BytesIO()\n", + " with request.urlopen(url) as resp, open(fname, 'wb') as f_disk:\n", + " data = resp.read()\n", + " f_disk.write(data) # save to disk too\n", + " f.write(data)\n", + " f.seek(0)\n", + "else:\n", + " f = open(fname, 'rb')\n", + "\n", + "\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 69, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kayjaymac/opt/anaconda3/lib/python3.9/bz2.py:124: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n", + " self._buffer = None\n", + "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n" + ] + } + ], + "source": [ + "#tar = tarfile.open(fileobj=f, mode=\"r:bz2\")\n", + "tar = tarfile.open(fileobj=f, mode=\"r\")\n", + "\n", + "# Extract the file we are interested in\n", + "\n", + "train_fname = \"op_task/train_op_data.jsonlist.bz2\"\n", + "test_fname = \"op_task/heldout_op_data.jsonlist.bz2\"\n", + "\n", + "train_bzlist = tar.extractfile(train_fname)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 70, + "outputs": [], + "source": [ + "# Deserialize the JSON list\n", + 
"original_posts_train = [\n", + " json.loads(line.decode('utf-8'))\n", + " for line in BZ2File(train_bzlist)\n", + "]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 71, + "outputs": [ + { + "data": { + "text/plain": "[{'title': \"CMV: I shouldn't get a job in this economic climate because it'll be automated anyway; I should just wait for a post-scarcity utopia.\",\n 'delta_label': False,\n 'name': 't3_2rpsl8',\n 'selftext': \"I think the world is automating fast enough that a utopia will arise where no one will have to work anymore. Within the next 2 decades or so, having a job won't mean much, and most people will be artists and scientists. \\n\\nMy parents let me live with them, so I can just wait until the utopia happens.\\n\\nCMV.\"}]" + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "original_posts_train[:1]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 72, + "outputs": [], + "source": [ + "# Load the jsonlist file into a dataframe\n", + "#df = pd.read_json(original_posts_train, orient='list', lines=True)\n", + "df = pd.DataFrame(original_posts_train)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 73, + "outputs": [], + "source": [ + "# Function to check if the posts still exists on reddit\n", + "def try_get_post(post_id):\n", + " try:\n", + " submission = reddit.submission(id=post_id)\n", + " submission.name\n", + " return True\n", + " except Exception as e:\n", + " return False" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 74, + "outputs": [], + "source": [ + "# Set up the detoxifier model:\n", + "from detoxify import Detoxify" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "import 
re\n", + "\n", + "# Removes > sign and the template message at the end of a message\n", + "def cleanup_body_text(cmv_post):\n", + " lines = [line for line in cmv_post.splitlines()\n", + " if not line.lstrip().startswith(\">\")\n", + " and not line.lstrip().startswith(\"____\")\n", + " and not line.lstrip().startswith(\"So go forth and CMV, noble redditors!\")\n", + " and \"edit\" not in \" \".join(line.lower().split()[:2])\n", + " ]\n", + " return \"\\n\".join(lines)\n", + "\n", + "\n", + "\n", + "\n", + "# Create the function that will be handling all the data gathering\n", + "def get_top_comment_and_clean_data(post_id):\n", + " #print(post_id.lstrip(\"t3_\"))\n", + " last_author = \"\"\n", + " # Grab the post\n", + " submission = reddit.submission(id=post_id.lstrip(\"t3_\"))\n", + " #print(submission.title)\n", + "\n", + " # Grab the highest rated comment on root layer\n", + " submission.submission_type = 'best'\n", + " submission.comments.replace_more(limit=0)\n", + " replies = list(submission.comments)[0].replies.list()\n", + "\n", + " # Just some variables\n", + " pros = []\n", + "\n", + " # If the post author doesn't exist this submission was deleted (submission.deleted doesn't work)\n", + " if type(submission.author) == type(None):\n", + " last_author = \"[deleted]\"\n", + " else:\n", + " last_author = submission.author.name\n", + "\n", + " is_pro_argument = False\n", + "\n", + " for comment in replies:\n", + "\n", + " # If redditor object doesn't exist, the account is invalid/deleted\n", + " if type(comment.author) != type(None):\n", + " author = comment.author.name\n", + " else:\n", + " author = \"[deleted]\"\n", + "\n", + " # Assume that whenever the user changes, they are countering the previous person\n", + " if author != last_author:\n", + " is_pro_argument = !is_pro_argument\n", + "\n", + " if author == \"[deleted]\" or author==\"DeltaBot\":\n", + " #print(\"Skipping comment...\")\n", + " continue\n", + "\n", + " # Remove meta and duplicate 
comments\n", + " comment.body = \" \".join([line for line in comment.body.splitlines()\n", + " if not re.search(r\"(?i)(Change\\smy\\sview|CMV)\", line)\n", + " and line not in pros # Why doesn't this line work\n", + " ])\n", + "\n", + " # Sometimes for some reason duplicate entries exist\n", + " # Also remove automated message with \"Δ\" in it\n", + "\n", + " if comment.body in pros:\n", + " #print(\"Skipping duplicate entry\")\n", + " continue\n", + "\n", + " #print(\"\\t\\t>>\\t\",comment.body)\n", + "\n", + " # Remove toxic comments\n", + " if Detoxify(\"multilingual\").predict(comment.body)[\"toxicity\"] > TOXIC_THRESHOLD:\n", + " #print(\"Identified toxic comment, ignoring...\")\n", + " comment.body = \"\"\n", + "\n", + " # Add to the respective argument type \n", + " if is_pro_argument:\n", + " pros.append(comment.body)\n", + " \n", + " last_author = comment.author.name\n", + " \n", + " # Pros = arguments for the Title of this post\n", + " # Cons = arguments against the title of this post\n", + "\n", + " pros.append(comment.body)\n", + " return pros" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading in 10 posts\n" + ] + } + ], + "source": [ + "print(f\"Loading in {ENTRIES_COUNT} posts\")\n", + "dataset = df.head(ENTRIES_COUNT)\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 77, + "outputs": [], + "source": [ + "# the name column does some weird sh** because dataframes already have a name property, so migrate to a different column name\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "dataset[\"post_id\"] = dataset[\"name\"]\n", + "warnings.filterwarnings('default')" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"Loading in data... This will take a while.\n" + ] + }, + { + "data": { + "text/plain": " 0%| | 0/10 [00:00:29: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + "/Users/kayjaymac/opt/anaconda3/lib/python3.9/site-packages/torch/serialization.py:997: ResourceWarning: unclosed file <_io.BufferedReader name='cmv.tar.bz2'>\n", + " storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()\n", + "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 7min 49s, sys: 2min 29s, total: 10min 19s\n", + "Wall time: 8min 45s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "from tqdm.auto import tqdm\n", + "# Reset variables for if we run this multiple times\n", + "all_pros = []\n", + "all_names = []\n", + "all_titles = []\n", + "all_sources = []\n", + "\n", + "print(\"Loading in data... 
This will take a while.\")\n", + "\n", + "for i in tqdm(range(dataset.shape[0])):\n", + "\n", + " post = dataset.iloc[i]\n", + " modified_title = post.title.replace('CMV', \"Change my mind\")\n", + " #print(f\"\\n Loading entry {i+1}/{dataset.shape[0]}:\\n\\t\\\"{modified_title}\\\"\")\n", + "\n", + " if type(post) == type(None):\n", + " continue\n", + "\n", + " assert(post.post_id != i)\n", + "\n", + " pros = get_top_comment_and_clean_data(post.post_id)\n", + "\n", + " if post.title == \"[deleted]\":\n", + " continue\n", + "\n", + " pros = \" \".join([*set(pros)])\n", + " pros = pros.replace(\"[deleted]\",\"\")\n", + "\n", + " post.selftext = cleanup_body_text(post.selftext)\n", + " all_titles.append(modified_title + \" \" + post.selftext)\n", + " all_pros.append(pros)\n", + " all_names.append(post.name)\n", + " all_sources.append(f\"https://reddit.com/r/changemyview/comments/{post.post_id}\")\n", + " #print(post.title)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "outputs": [ + { + "data": { + "text/plain": "'it\\'s already been signed. They even claim to be adhering to it, though they\\'ve been found to be violating it before. There is no such thing as \"de facto acceptance of Israel\\'s nuclear program.\" the Non-Proliferation Treaty is only binding for signatory states. Israel is not a signatory. Article 10 of the NPT allows them to withdraw if they so choose. they have not done so. a whole new country which explicitly has a right to withdraw from the NPT and has not chosen to do so. It\\'s more accurate, I think, to say that the problem with Iran here from a legal standpoint is that they aren\\'t honoring their own commitments, rather than that they\\'re building weapons. They could pull out of the NPT at any time, and the ball would be essentially in America\\'s court, because their nuclear program would no longer be illegal by international legal standards. 
However, Iran insists both on developing nukes *and* remaining an NPT signatory non-nuclear state, and that\\'s what makes their program illegal. I\\'d also like to clarify that I\\'m not making an ethical argument here, this is just how international law currently works. because international law doesn\\'t require states to sign treaties, it only requires them to adhere to treaties they\\'ve already signed. Israel isn\\'t defying the UN, at least not in this particular case. Think of the NPT less like a standard law within a state and more like a contract. Once you\\'ve signed, you\\'re bound by the contract, but if you never sign it then you haven\\'t broken a law, you\\'ve just decided not to agree to the terms you were offered. > Because Iran did sign the treaty, and thus are bound by it. They signed on July 1, 1968. Hmm. So is the argument here that it\\'s not \"ok\" for Iran to have a nuke, since they signed treaty not to do so. But it\\'s \"ok\" for Israel to have one because they never signed such thing? 
Can\\'t quite put my finger on it, but doesn\\'t seem quite right this one.'" + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_pros[1]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 80, + "outputs": [], + "source": [ + "# Place it all into a Pandas Dataframe\n", + "clean_df = pd.DataFrame({\n", + " \"INSTRUCTION\": all_titles,\n", + " \"RESPONSE\": all_pros,\n", + " \"SOURCE\": all_sources\n", + "}, index=all_names\n", + ")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Apache Paquete file\n", + "\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "\n", + "table = pa.Table.from_pandas(clean_df)\n", + "pq.write_table(table,\"output.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "outputs": [ + { + "data": { + "text/plain": " INSTRUCTION \\\n0 Change my mind: I shouldn't get a job in this ... \n1 Change my mind: Iran has the right to develop ... \n2 Change my mind: The events in Paris suck...but... \n3 Change my mind: It is ok to hate a religion so... \n4 Change my mind: There is no productive reason ... \n5 Change my mind: Diet soda is perfectly healthy... \n6 Change my mind:Essential Oils are bullshit My ... \n7 Change my mind: I think the Paris shooting mak... \n8 Change my mind: Printing an image of the Musli... \n9 Change my mind: Philosophy has no tangible val... \n\n RESPONSE \\\n0 That is what someone in the 1500s would have s... \n1 it's already been signed. They even claim to b... \n2 Hm I guess I made the OP incorrectly. The mai... \n3 I don't understand your analogy. Promoting a ... \n4 ∆ I hadn't thought it from a \"let's trick peop... \n5 Thanks for a fresh argument! I hadn't conside... \n6 Most do. Some smell kinda funky. \n7 I already said in different comments that thi... 
\n8 The first bacon sandwich came about because 9... \n9 >Why restrict it to 50 years? I can name all s... \n\n SOURCE \n0 https://reddit.com/r/changemyview/comments/t3_... \n1 https://reddit.com/r/changemyview/comments/t3_... \n2 https://reddit.com/r/changemyview/comments/t3_... \n3 https://reddit.com/r/changemyview/comments/t3_... \n4 https://reddit.com/r/changemyview/comments/t3_... \n5 https://reddit.com/r/changemyview/comments/t3_... \n6 https://reddit.com/r/changemyview/comments/t3_... \n7 https://reddit.com/r/changemyview/comments/t3_... \n8 https://reddit.com/r/changemyview/comments/t3_... \n9 https://reddit.com/r/changemyview/comments/t3_... ", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
INSTRUCTIONRESPONSESOURCE
0Change my mind: I shouldn't get a job in this ...That is what someone in the 1500s would have s...https://reddit.com/r/changemyview/comments/t3_...
1Change my mind: Iran has the right to develop ...it's already been signed. They even claim to b...https://reddit.com/r/changemyview/comments/t3_...
2Change my mind: The events in Paris suck...but...Hm I guess I made the OP incorrectly. The mai...https://reddit.com/r/changemyview/comments/t3_...
3Change my mind: It is ok to hate a religion so...I don't understand your analogy. Promoting a ...https://reddit.com/r/changemyview/comments/t3_...
4Change my mind: There is no productive reason ...∆ I hadn't thought it from a \"let's trick peop...https://reddit.com/r/changemyview/comments/t3_...
5Change my mind: Diet soda is perfectly healthy...Thanks for a fresh argument! I hadn't conside...https://reddit.com/r/changemyview/comments/t3_...
6Change my mind:Essential Oils are bullshit My ...Most do. Some smell kinda funky.https://reddit.com/r/changemyview/comments/t3_...
7Change my mind: I think the Paris shooting mak...I already said in different comments that thi...https://reddit.com/r/changemyview/comments/t3_...
8Change my mind: Printing an image of the Musli...The first bacon sandwich came about because 9...https://reddit.com/r/changemyview/comments/t3_...
9Change my mind: Philosophy has no tangible val...>Why restrict it to 50 years? I can name all s...https://reddit.com/r/changemyview/comments/t3_...
\n
" + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test to see if it was sucessful\n", + "table = pq.read_table(\"output.parquet\")\n", + "table.to_pandas()" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} From 5d4f74f9d60bb27c3d29d3c50f3b855869d59861 Mon Sep 17 00:00:00 2001 From: MattAlexMiracle Date: Thu, 26 Jan 2023 10:50:25 +0100 Subject: [PATCH 018/101] Ranked pairs (#933) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * commented out legacy numerical solver * added comments and task_scheduling for selecting which task to serve to users * removed standalone task weighting * pre-commit hook rerun * fixed ranking * fix index error * ranking fix * fix typo Co-authored-by: Alexander Mattick Co-authored-by: Andreas Köpf --- backend/oasst_backend/utils/ranking.py | 33 ++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/backend/oasst_backend/utils/ranking.py b/backend/oasst_backend/utils/ranking.py index 5538d7a3..0bb94fe8 100644 --- a/backend/oasst_backend/utils/ranking.py +++ b/backend/oasst_backend/utils/ranking.py @@ -96,13 +96,15 @@ def ranked_pairs(ranks: List[List[int]]): """ tallies, names = head_to_head_votes(ranks) tallies = tallies - tallies.T - # print(tallies) # note: the resulting tally matrix should be skew-symmetric # order by strength of victory (using tideman's original method, don't think it would make a difference for us) sorted_majorities = [] for i in range(len(ranks[0])): for j 
in range(len(ranks[0])): - if tallies[i, j] > 0: + # you can never prefer yourself over yourself + # we also have to pick one of the two choices, + # if the preference is exactly zero... + if tallies[i, j] >= 0 and i != j: sorted_majorities.append((i, j, tallies[i, j])) # we don't explicitly deal with tied majorities here sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True)) @@ -128,13 +130,36 @@ def ranked_pairs(ranks: List[List[int]]): if __name__ == "__main__": - ranks = ( + + ranks = """ ( [("w", "x", "z", "y") for _ in range(1)] + [("w", "y", "x", "z") for _ in range(2)] # + [("x","y","z","w") for _ in range(4)] + [("x", "z", "w", "y") for _ in range(5)] + [("y", "w", "x", "z") for _ in range(1)] # [("y","z","w","x") for _ in range(1000)] - ) + )""" + ranks = [ + [ + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ], + [ + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ], + [ + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ], + [ + ("d11705af-5575-43e5-b22e-08d155fbaa62"), + ("c5181083-d3e9-41e7-a935-83fb9fa01488"), + ("dcf3d179-0f34-4c15-ae21-b8feb15e422d"), + ], + ] rp = ranked_pairs(ranks) print(rp) From c2fa476904552ceca5675568f7645cae22de26fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 15:29:54 +0100 Subject: [PATCH 019/101] Add user emoji augmentation for message queries (#937) * add disposition to text labeling tasks * add emoji stats to ConversationMessage * add user emoji augmentation for message queries * add auth_method,username to message queries (query emoji status) * add auth_method+username for single message * fix param name typo * only join rows when message.emojis != JSON.NULL * formatting * make sure emojis and 
user_emojis default to {}, [] * remove init_user(), use fresh empty default collections --- .../oasst_backend/api/v1/frontend_users.py | 2 +- backend/oasst_backend/api/v1/messages.py | 60 ++++++++++++----- backend/oasst_backend/api/v1/users.py | 2 +- backend/oasst_backend/api/v1/utils.py | 5 +- backend/oasst_backend/models/message.py | 17 ++++- backend/oasst_backend/prompt_repository.py | 67 ++++++++++++++++--- backend/oasst_backend/tree_manager.py | 8 +++ oasst-shared/oasst_shared/schemas/protocol.py | 28 +++++--- 8 files changed, 149 insertions(+), 40 deletions(-) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 5ea7b26c..86f78026 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -77,7 +77,7 @@ def query_frontend_user_messages( """ Query frontend user messages. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index af3ae42d..b3aace40 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -34,7 +34,7 @@ def query_messages( """ Query messages. 
""" - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.query_messages_ordered_by_created_date( auth_method=auth_method, username=username, @@ -93,7 +93,7 @@ def get_messages_cursor( qry_max_count = max_count + 1 if before is None or after is None else max_count - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username, user_id=user_id) items = pr.query_messages_ordered_by_created_date( user_id=user_id, auth_method=auth_method, @@ -137,37 +137,49 @@ def get_messages_cursor( @router.get("/{message_id}", response_model=protocol.Message) def get_message( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a message by its internal ID. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) return utils.prepare_message(message) @router.get("/{message_id}/conversation", response_model=protocol.Conversation) def get_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a conversation from the tree root and up to the message with given internal ID. 
""" - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.fetch_message_conversation(message_id) return utils.prepare_conversation(messages) @router.get("/{message_id}/tree", response_model=protocol.MessageTree) def get_tree( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get all messages belonging to the same message tree. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False) return utils.prepare_tree(tree, message.message_tree_id) @@ -175,24 +187,32 @@ def get_tree( @router.get("/{message_id}/children", response_model=list[protocol.Message]) def get_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get all messages belonging to the same message tree. 
""" - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) messages = pr.fetch_message_children(message_id) return utils.prepare_message_list(messages) @router.get("/{message_id}/descendants", response_model=protocol.MessageTree) def get_descendants( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get a subtree which starts with this message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) descendants = pr.fetch_message_descendants(message) return utils.prepare_tree(descendants, message.id) @@ -200,12 +220,16 @@ def get_descendants( @router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation) def get_longest_conv( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get the longest conversation from the tree of the message. 
""" - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) conv = pr.fetch_longest_conversation(message.message_tree_id) return utils.prepare_conversation(conv) @@ -213,12 +237,16 @@ def get_longest_conv( @router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree) def get_max_children( - message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) + message_id: UUID, + auth_method: Optional[str] = None, + username: Optional[str] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), ): """ Get message with the most children from the tree of the provided message. """ - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, auth_method=auth_method, username=username) message = pr.fetch_message(message_id) message, children = pr.fetch_message_with_max_children(message.message_tree_id) return utils.prepare_tree([message, *children], message.id) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index c0055339..d7497610 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -230,7 +230,7 @@ def query_user_messages( """ Query user messages. 
""" - pr = PromptRepository(db, api_client) + pr = PromptRepository(db, api_client, user_id=user_id) messages = pr.query_messages_ordered_by_created_date( user_id=user_id, api_client_id=api_client_id, diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 8b0f378f..5c9537a3 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -14,7 +14,8 @@ def prepare_message(m: Message) -> protocol.Message: lang=m.lang, is_assistant=(m.role == "assistant"), created_date=m.created_date, - emojis=m.emojis, + emojis=m.emojis or {}, + user_emojis=m.user_emojis or [], ) @@ -30,6 +31,8 @@ def prepare_conversation_message_list(messages: list[Message]) -> list[protocol. text=message.text, lang=message.lang, is_assistant=(message.role == "assistant"), + emojis=message.emojis or {}, + user_emojis=message.user_emojis or [], ) for message in messages ] diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index da0c06c3..5f323d5d 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,12 +1,13 @@ from datetime import datetime from http import HTTPStatus -from typing import Optional +from typing import Any, Optional from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg from oasst_backend.models.db_payload import MessagePayload from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from pydantic import PrivateAttr from sqlalchemy import false from sqlmodel import Field, Index, SQLModel @@ -17,6 +18,13 @@ class Message(SQLModel, table=True): __tablename__ = "message" __table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),) + def __new__(cls, *args: Any, **kwargs: Any): + new_object = super().__new__(cls, *args, **kwargs) + # temporary fix until https://github.com/tiangolo/sqlmodel/issues/149 gets 
merged + if not hasattr(new_object, "_user_emojis"): + new_object._init_private_attributes() + return new_object + id: Optional[UUID] = Field( sa_column=sa.Column( pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") @@ -49,7 +57,8 @@ class Message(SQLModel, table=True): rank: Optional[int] = Field(nullable=True) - emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False) + emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False) + _user_emojis: Optional[list[str]] = PrivateAttr(default=None) def ensure_is_message(self) -> None: if not self.payload or not isinstance(self.payload.payload, MessagePayload): @@ -59,3 +68,7 @@ class Message(SQLModel, table=True): def text(self) -> str: self.ensure_is_message() return self.payload.payload.text + + @property + def user_emojis(self) -> str: + return self._user_emojis diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 7dddb5cf..b31b53d7 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -30,8 +30,9 @@ from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats from oasst_shared.utils import unaware_to_utc +from sqlalchemy.orm import Query from sqlalchemy.orm.attributes import flag_modified -from sqlmodel import Session, and_, func, not_, or_, text, update +from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -41,14 +42,25 @@ class PromptRepository: db: Session, api_client: ApiClient, client_user: Optional[protocol_schema.User] = None, + *, user_repository: Optional[UserRepository] = None, task_repository: Optional[TaskRepository] = None, + user_id: Optional[UUID] = None, 
+ auth_method: Optional[str] = None, + username: Optional[str] = None, ): self.db = db self.api_client = api_client self.user_repository = user_repository or UserRepository(db, api_client) - self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) - self.user_id = self.user.id if self.user else None + if user_id: + self.user = self.user_repository.get_user(id=user_id) + self.user_id = self.user.id + elif auth_method and username: + self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username) + self.user_id = self.user.id + else: + self.user = self.user_repository.lookup_client_user(client_user, create_missing=True) + self.user_id = self.user.id if self.user else None logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})") self.task_repository = task_repository or TaskRepository( db, api_client, client_user, user_repository=self.user_repository @@ -529,7 +541,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if not include_deleted: qry = qry.filter(not_(Message.deleted)) - return qry.all() + return self._add_user_emojis_all(qry) def fetch_user_message_trees( self, user_id: Message.user_id, reviewed: bool = True, include_deleted: bool = False @@ -539,7 +551,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if not include_deleted: qry = qry.filter(not_(Message.deleted)) - return qry.all() + return self._add_user_emojis_all(qry) def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]: qry = self.db.query(MessageTreeState).filter( @@ -582,6 +594,10 @@ class PromptRepository: return conversation, replies def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]: + qry = self.db.query(Message).filter(Message.id == message_id) + messages = self._add_user_emojis_all(qry) + message = messages[0] if messages else None + message = self.db.query(Message).filter(Message.id == 
message_id).one_or_none() if fail_if_missing and not message: raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) @@ -656,7 +672,7 @@ class PromptRepository: qry = qry.filter(Message.review_result) if exclude_deleted: qry = qry.filter(Message.deleted == sa.false()) - children = qry.all() + children = self._add_user_emojis_all(qry) return children def fetch_message_siblings( @@ -674,7 +690,7 @@ class PromptRepository: qry = qry.filter(Message.review_result == reviewed) if deleted is not None: qry = qry.filter(Message.deleted == deleted) - siblings = qry.all() + siblings = self._add_user_emojis_all(qry) return siblings @staticmethod @@ -705,7 +721,7 @@ class PromptRepository: if max_depth is not None: desc = desc.filter(Message.depth <= max_depth) - desc = desc.all() + desc = self._add_user_emojis_all(desc) return self.trace_descendants(message, desc) @@ -719,6 +735,33 @@ class PromptRepository: max_message = max(tree, key=lambda m: m.children_count) return max_message, [m for m in tree if m.parent_id == max_message.id] + def _add_user_emojis_all(self, qry: Query) -> list[Message]: + if self.user_id is None: + return qry.all() + + sq = qry.subquery("m") + qry = ( + self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis")) + .select_entity_from(sq) + .outerjoin( + MessageEmoji, + and_( + sq.c.id == MessageEmoji.message_id, + MessageEmoji.user_id == self.user_id, + sq.c.emojis != JSON.NULL, + ), + ) + .group_by(sq) + ) + messages: list[Message] = [] + for x in qry: + m: Message = x.Message + user_emojis = x["user_emojis"] + if user_emojis: + m._user_emojis = user_emojis.split(",") + messages.append(m) + return messages + def query_messages_ordered_by_created_date( self, user_id: Optional[UUID] = None, @@ -801,7 +844,7 @@ class PromptRepository: if lang is not None: qry = qry.filter(Message.lang == lang) - return qry.all() + return self._add_user_emojis_all(qry) def 
update_children_counts(self, message_tree_id: UUID): sql_update_children_count = """ @@ -902,9 +945,15 @@ WHERE message.id = cc.id; else: count = emoji_counts.get(emoji.value) or 0 emoji_counts[emoji.value] = count + 1 + if message._user_emojis is None: + message._user_emojis = [] + if emoji.value not in message._user_emojis: + message._user_emojis.append(emoji.value) elif op == protocol_schema.EmojiOp.remove: # remove emoji record and & decrement count message = self.fetch_message(message_id) + if message._user_emojis and emoji.value in message._user_emojis: + message._user_emojis.remove(emoji.value) self.db.delete(existing_emoji) emoji_counts = message.emojis count = emoji_counts.get(emoji.value) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 64e1883e..929a9297 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -354,6 +354,7 @@ class TreeManager: self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 label_mode = protocol_schema.LabelTaskMode.full + label_disposition = protocol_schema.LabelTaskDisposition.quality valid_labels = self._all_text_labels if message.role == "assistant": @@ -363,6 +364,8 @@ class TreeManager: ): valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam + logger.info(f"Generating a LabelAssistantReplyTask. 
({label_mode=:s})") task = protocol_schema.LabelAssistantReplyTask( message_id=message.id, @@ -371,6 +374,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), mode=label_mode, + disposition=label_disposition, ) else: if ( @@ -387,6 +391,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), mode=label_mode, + disposition=label_disposition, ) parent_message_id = message.id @@ -424,11 +429,13 @@ class TreeManager: message = random.choice(prompts_need_review) label_mode = protocol_schema.LabelTaskMode.full + label_disposition = protocol_schema.LabelTaskDisposition.quality valid_labels = self._all_text_labels if random.random() > self.cfg.p_full_labeling_review_prompt: valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) label_mode = protocol_schema.LabelTaskMode.simple + label_disposition = protocol_schema.LabelTaskDisposition.spam logger.info(f"Generating a LabelInitialPromptTask ({label_mode=:s}).") task = protocol_schema.LabelInitialPromptTask( @@ -437,6 +444,7 @@ class TreeManager: valid_labels=valid_labels, mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), mode=label_mode, + disposition=label_disposition, ) parent_message_id = message.id diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index de431d75..31caa340 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -57,6 +57,8 @@ class ConversationMessage(BaseModel): text: str lang: Optional[str] # BCP 47 is_assistant: bool + emojis: Optional[dict[str, int]] = None + user_emojis: Optional[list[str]] = None class Conversation(BaseModel): @@ -80,7 +82,6 @@ class Conversation(BaseModel): class Message(ConversationMessage): parent_id: Optional[UUID] = None 
created_date: Optional[datetime] = None - emojis: Optional[dict] = None class MessagePage(PageResult): @@ -223,27 +224,34 @@ class LabelTaskMode(str, enum.Enum): full = "full" -class LabelInitialPromptTask(Task): - """A task to label an initial prompt.""" +class LabelTaskDisposition(str, enum.Enum): + """Reason why the task was issued.""" - type: Literal["label_initial_prompt"] = "label_initial_prompt" + quality = "quality" + spam = "spam" + + +class AbstractLabelTask(Task): message_id: UUID - prompt: str valid_labels: list[str] mandatory_labels: Optional[list[str]] mode: Optional[LabelTaskMode] + disposition: Optional[LabelTaskDisposition] -class LabelConversationReplyTask(Task): +class LabelInitialPromptTask(AbstractLabelTask): + """A task to label an initial prompt.""" + + type: Literal["label_initial_prompt"] = "label_initial_prompt" + prompt: str + + +class LabelConversationReplyTask(AbstractLabelTask): """A task to label a reply to a conversation.""" type: Literal["label_conversation_reply"] = "label_conversation_reply" conversation: Conversation # the conversation so far - message_id: UUID reply: str - valid_labels: list[str] - mandatory_labels: Optional[list[str]] - mode: Optional[LabelTaskMode] class LabelPrompterReplyTask(LabelConversationReplyTask): From d4688835d54a0da51c986cf2b43cfbe14fdfdb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 16:33:03 +0100 Subject: [PATCH 020/101] check condition for scoring on startup --- backend/oasst_backend/tree_manager.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 929a9297..a4cad0c5 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -542,9 +542,7 @@ class TreeManager: ) _, task = pr.store_ranking(interaction) - - ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id) - 
self.update_message_ranks(task.message_tree_id, rankings_by_message) + self.check_condition_for_scoring_state(task.message_tree_id) case protocol_schema.TextLabels: logger.info( @@ -659,7 +657,8 @@ class TreeManager: return False, None self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) - return True, rankings_by_message + self.update_message_ranks(message_tree_id, rankings_by_message) + return True def update_message_ranks( self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]] @@ -976,7 +975,7 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki return rankings_by_message @managed_tx_method(CommitMode.COMMIT) - def ensure_tree_states(self): + def ensure_tree_states(self) -> None: """Add message tree state rows for all root nodes (inital prompt messages).""" missing_tree_ids = self.query_misssing_tree_states() @@ -988,6 +987,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})") self._insert_default_state(id, state=state) + rankings = ( + self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all() + ) + if len(rankings) > 0: + logger.info(f"Checking state of {len(rankings)} message trees in ranking state.") + for r in rankings: + self.check_condition_for_scoring_state(r.message_tree_id) + def query_num_active_trees(self, lang: str) -> int: query = ( self.db.query(func.count(MessageTreeState.message_tree_id)) From f1edcc8a285dbc184a14ab50dccc39d258d45c92 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Thu, 26 Jan 2023 16:41:57 +0100 Subject: [PATCH 021/101] added streaming worker --- inference/README.md | 7 ++ inference/worker/__main__.py | 66 +++++++++++++------ inference/worker/requirements.txt | 4 +- .../oasst_shared/schemas/inference.py | 4 ++ 4 files changed, 58 insertions(+), 23 deletions(-) diff --git 
a/inference/README.md b/inference/README.md index 3dee94f9..bd0272ad 100644 --- a/inference/README.md +++ b/inference/README.md @@ -26,6 +26,13 @@ pip install -r requirements.txt python __main__.py ``` +For the worker, you'll also want to have the text-generation-inference server +running: + +```bash +docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference +``` + Run the client: ```bash diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index ad5e5cef..c8c1a4c9 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -1,13 +1,12 @@ -import re -import time +import json import rel -import torch +import requests +import sseclient import typer import websocket from loguru import logger from oasst_shared.schemas import inference, protocol -from transformers import pipeline app = typer.Typer() @@ -16,9 +15,8 @@ app = typer.Typer() def main( backend_url: str = "ws://localhost:8000", model_name: str = "distilgpt2", + inference_server_url: str = "http://localhost:8001", ): - pipe = pipeline("text-generation", model=model_name) - def on_open(ws: websocket.WebSocket): worker_config = inference.WorkerConfig(model_name=model_name) ws.send(worker_config.json()) @@ -37,23 +35,49 @@ def main( prompt = "\n".join(messages) + "\nAssistant:" - # TODO: replace this with incremental generation - torch.manual_seed(work_request.seed) - model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ - 0 - ]["generated_text"] - model_output = model_output.strip() + # TODO: use the seed + # torch.manual_seed(work_request.seed) + # model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ + # 0 + # ]["generated_text"] + # model_output = model_output.strip() - # fake streaming - split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] - pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, 
split_idcs + [None])] - for piece in pieces: - if not piece: - continue - if piece.strip() in ("User:", "Assistant:"): + # # fake streaming + # split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] + # pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] + # for piece in pieces: + # if not piece: + # continue + # if piece.strip() in ("User:", "Assistant:"): + # break + # ws.send(inference.WorkResponsePacket(token=piece).json()) + # time.sleep(0.1) + # ws.send(inference.WorkResponsePacket(is_end=True).json()) + + response = requests.post( + f"{inference_server_url}/generate_stream", + json={ + "inputs": prompt, + "parameters": { + "max_new_tokens": work_request.max_new_tokens, + "do_sample": work_request.do_sample, + "top_k": work_request.top_k, + "top_p": work_request.top_p, + "temperature": work_request.temperature, + }, + }, + stream=True, + headers={"Accept": "text/event-stream"}, + ) + response.raise_for_status() + + client = sseclient.SSEClient(response) + for event in client.events(): + data = json.loads(event.data) + if data["is_end"]: break - ws.send(inference.WorkResponsePacket(token=piece).json()) - time.sleep(0.1) + intermediate = data["event"] + ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json()) ws.send(inference.WorkResponsePacket(is_end=True).json()) def on_error(ws: websocket.WebSocket, error: Exception): diff --git a/inference/worker/requirements.txt b/inference/worker/requirements.txt index c248c652..82169379 100644 --- a/inference/worker/requirements.txt +++ b/inference/worker/requirements.txt @@ -1,6 +1,6 @@ loguru rel -torch -transformers +requests +sseclient-py typer websocket-client diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 0acb5014..b50cef9c 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -14,6 +14,10 @@ class 
WorkRequest(pydantic.BaseModel): model_name: str = "distilgpt2" max_new_tokens: int = 100 seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + do_sample: bool = True + top_k: int = 50 + top_p: float = 0.9 + temperature: float = 1.0 class WorkResponsePacket(pydantic.BaseModel): From 348999a93636bfcadb72256175a0a7115ff0cf63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 19:06:25 +0100 Subject: [PATCH 022/101] exclude trees in ranking state in acitve tree count --- backend/oasst_backend/tree_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a4cad0c5..992e75dd 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -202,7 +202,7 @@ class TreeManager: lang = "en" logger.warning("Task availability request without lang tag received, assuming lang='en'.") - num_active_trees = self.query_num_active_trees(lang=lang) + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) extendible_parents = self.query_extendible_parents(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) @@ -230,7 +230,7 @@ class TreeManager: lang = "en" logger.warning("Task request without lang tag received, assuming 'en'.") - num_active_trees = self.query_num_active_trees(lang=lang) + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) extendible_parents = self.query_extendible_parents(lang=lang) @@ -995,12 +995,15 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki for r in rankings: self.check_condition_for_scoring_state(r.message_tree_id) - def query_num_active_trees(self, lang: 
str) -> int: + def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int: + """Count all active trees (optionally exclude those in ranking state).""" query = ( self.db.query(func.count(MessageTreeState.message_tree_id)) .join(Message, MessageTreeState.message_tree_id == Message.id) .filter(MessageTreeState.active, Message.lang == lang) ) + if exclude_ranking: + query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING) return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: From 040344a41f8beef9c0a3ad4e82474326186b121f Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Thu, 26 Jan 2023 21:01:52 +0100 Subject: [PATCH 023/101] made inference server a bit more robust --- inference/server/main.py | 88 ++++++++++++++++++++++-------------- inference/worker/__main__.py | 8 +++- 2 files changed, 60 insertions(+), 36 deletions(-) diff --git a/inference/server/main.py b/inference/server/main.py index f3ec02b1..4cb5f659 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -5,6 +5,7 @@ import uuid import fastapi import pydantic import redis.asyncio as redis +import websockets.exceptions from fastapi.middleware.cors import CORSMiddleware from loguru import logger from oasst_shared.schemas import inference, protocol @@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum): pending = "pending" in_progress = "in_progress" complete = "complete" + aborted_by_worker = "aborted_by_worker" class DbChatEntry(pydantic.BaseModel): @@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque async def work(websocket: fastapi.WebSocket): await websocket.accept() worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) - while True: - # find a pending task that matches the worker's config - # could also be implemented using task queues - # but general compatibility matching is tricky - for chat in CHATS.values(): - if 
(request := chat.pending_message_request) is not None: - if chat.message_request_state == MessageRequestState.pending: - if request.compatible_with(worker_config): + try: + while True: + print(websocket.client_state) + if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: + logger.warning("Worker disconnected") + break + # find a pending task that matches the worker's config + # could also be implemented using task queues + # but general compatibility matching is tricky + for chat in CHATS.values(): + if (request := chat.pending_message_request) is not None: + if chat.message_request_state == MessageRequestState.pending: + if request.compatible_with(worker_config): + break + else: + logger.debug("No pending tasks") + await asyncio.sleep(1) + continue + + chat.message_request_state = MessageRequestState.in_progress + + work_request = inference.WorkRequest( + conversation=chat.conversation, + model_name=request.model_name, + max_new_tokens=request.max_new_tokens, + ) + + logger.info(f"Created {work_request}") + try: + await websocket.send_text(work_request.json()) + except websockets.exceptions.ConnectionClosedError: + logger.warning("Worker disconnected") + websocket.close() + chat.message_request_state = MessageRequestState.pending + break + + try: + while True: + # maybe unnecessary to parse and re-serialize + # could just pass the raw string and mark end via empty string + response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + await redisClient.rpush(chat.id, response_packet.json()) + if response_packet.is_end: break - else: - logger.debug("No pending tasks") - await asyncio.sleep(1) - continue + except fastapi.WebSocketException: + # TODO: handle this better + logger.exception(f"Websocket closed during handling of {chat.id}") + chat.message_request_state = MessageRequestState.aborted_by_worker + raise - chat.message_request_state = MessageRequestState.in_progress - - work_request = inference.WorkRequest( 
- conversation=chat.conversation, - model_name=request.model_name, - max_new_tokens=request.max_new_tokens, - ) - - logger.info(f"Created {work_request}") - try: - await websocket.send_text(work_request.json()) - while True: - # maybe unnecessary to parse and re-serialize - # could just pass the raw string and mark end via empty string - response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) - await redisClient.rpush(chat.id, response_packet.json()) - if response_packet.is_end: - break - except fastapi.WebSocketException: - # TODO: handle this better - logger.exception(f"Websocket closed during handling of {chat.id}") - - chat.message_request_state = MessageRequestState.complete + chat.message_request_state = MessageRequestState.complete + except fastapi.WebSocketException: + logger.exception("Websocket closed") diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index c8c1a4c9..e5c15fb4 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -33,7 +33,13 @@ def main( # construct prompt messages = [_prepare_message(message) for message in work_request.conversation.messages] - prompt = "\n".join(messages) + "\nAssistant:" + prefix = ( + "The following is a conversation between a user and an assistant. " + "The assistant is helpful, creative, clever, and very friendly.\n" + "Assistant: Hello! 
How can I help you today?\n" + ) + + prompt = prefix + "\n".join(messages) + "\nAssistant:" # TODO: use the seed # torch.manual_seed(work_request.seed) From f3ffde47ffc1cf2f49a8f835cf2c1a4a38ebcc9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 26 Jan 2023 23:00:54 +0100 Subject: [PATCH 024/101] add preferred lonely_children extension (#942) * add preferred lonely_children extension * simplify sampling process, lower the probability to 25% * exclude parents for replies that were recently used * lonely children := count > 0 * consider only tasks not done for parent exclusion * increase lonely child sampling probability --- ...84fcd6900dc_add_task_created_date_index.py | 26 ++++++++++++++ backend/oasst_backend/config.py | 9 +++++ backend/oasst_backend/models/task.py | 4 ++- backend/oasst_backend/task_repository.py | 16 ++++++++- backend/oasst_backend/tree_manager.py | 34 ++++++++++++++++--- 5 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py diff --git a/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py b/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py new file mode 100644 index 00000000..29fd1aec --- /dev/null +++ b/backend/alembic/versions/2023_01_26_1835-c84fcd6900dc_add_task_created_date_index.py @@ -0,0 +1,26 @@ +"""add task created date index + +Revision ID: c84fcd6900dc +Revises: 40ed93df0ed5 +Create Date: 2023-01-26 18:35:43.061589 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c84fcd6900dc" +down_revision = "40ed93df0ed5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_index(op.f("ix_task_created_date"), "task", ["created_date"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_task_created_date"), table_name="task") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index d831eca2..9952c654 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -57,6 +57,15 @@ class TreeManagerConfiguration(BaseModel): rank_prompter_replies: bool = False + lonely_children_count: int = 3 + """Number of children below which parents are preferred during sampling for reply tasks.""" + + p_lonely_child_extension: float = 0.8 + """Probability to select a parent with less than lonely_children_count children.""" + + recent_tasks_span_sec: int = 3 * 60 # 3 min + """Time in seconds of recent tasks to consider for exclusion during task selection.""" + class Settings(BaseSettings): PROJECT_NAME: str = "open-assistant backend" diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py index a59f689e..7f91b157 100644 --- a/backend/oasst_backend/models/task.py +++ b/backend/oasst_backend/models/task.py @@ -20,7 +20,9 @@ class Task(SQLModel, table=True): ), ) created_date: Optional[datetime] = Field( - sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()), + sa_column=sa.Column( + sa.DateTime(timezone=True), nullable=False, index=True, server_default=sa.func.current_timestamp() + ), ) expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index eb100fe3..5c5dea21 100644 --- a/backend/oasst_backend/task_repository.py +++ 
b/backend/oasst_backend/task_repository.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import Optional from uuid import UUID @@ -9,7 +10,7 @@ from oasst_backend.user_repository import UserRepository from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session +from sqlmodel import Session, func, or_ from starlette.status import HTTP_404_NOT_FOUND @@ -219,3 +220,16 @@ class TaskRepository: def fetch_task_by_id(self, task_id: UUID) -> Task: task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none() return task + + def fetch_recent_reply_tasks( + self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100 + ) -> list[Task]: + qry = self.db.query(Task).filter( + func.age(Task.created_date) < max_age, + or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"), + ) + if done is not None: + qry = qry.filter(Task.done == done) + if limit: + qry = qry.limit(limit) + return qry.all() diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 992e75dd..89a51807 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -1,7 +1,7 @@ import json import random import sys -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple @@ -339,6 +339,7 @@ class TreeManager: message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_REPLY: + if task_role == TaskRole.PROMPTER: replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review)) elif task_role == TaskRole.ASSISTANT: @@ -398,19 +399,44 @@ class TreeManager: message_tree_id = 
message.message_tree_id case TaskType.REPLY: - # select a tree with missing replies + + recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks( + max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False + ) + recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks} + if task_role == TaskRole.PROMPTER: extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents)) elif task_role == TaskRole.ASSISTANT: extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents)) + # select a tree with missing replies if len(extendible_parents) > 0: - random_parent = random.choice(extendible_parents) + random_parent: ExtendibleParentRow = None + if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1: + # check if we have extendible parents with a small number of replies + + lonely_children_parents = [ + p + for p in extendible_parents + if 0 < p.active_children_count < self.cfg.lonely_children_count + and p.parent_id not in recent_reply_task_parents + ] + if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension: + random_parent = random.choice(lonely_children_parents) + + if random_parent is None: + # try to exclude parents for which tasks were recently handed out + fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents] + if len(fresh_parents) > 0: + random_parent = random.choice(fresh_parents) + else: + random_parent = random.choice(extendible_parents) # fetch random conversation to extend logger.debug(f"selected {random_parent=}") messages = self.pr.fetch_message_conversation(random_parent.parent_id) - assert all(m.review_result for m in messages) # ensure all messages have positive review + assert all(m.review_result for m in messages) # ensure all messages have positive reviews conversation = prepare_conversation(messages) # generate reply task depending on last message 
From da1c81d2c9ffaf748a94dddce1de8ea4a014b341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 00:54:29 +0100 Subject: [PATCH 025/101] Add LabelDescription list to labeling tasks, make +1/-1 emojis exclusive (#947) * add LabelDescription list to labeling tasks * make +1 & -1 emoji exclusive (only one of both or none) * add red_flag emoji to message when reported * fix task's valid labels * fix typo --- backend/oasst_backend/api/v1/text_labels.py | 29 +++++++++- backend/oasst_backend/config.py | 51 ++++++++++++++-- backend/oasst_backend/models/message.py | 6 ++ backend/oasst_backend/prompt_repository.py | 35 ++++++++--- backend/oasst_backend/schemas/text_labels.py | 11 +--- backend/oasst_backend/tree_manager.py | 33 +++++++---- oasst-shared/oasst_shared/schemas/protocol.py | 58 ++++++++++++------- 7 files changed, 169 insertions(+), 54 deletions(-) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index dc6cc889..2025fd4c 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -3,10 +3,11 @@ from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository -from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse +from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TextLabel from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -45,7 +46,29 @@ def label_text( def get_valid_lables() -> ValidLabelsResponse: return ValidLabelsResponse( valid_labels=[ - LabelOption(name=l.value, display_text=l.display_text, 
help_text=l.help_text) - for l in protocol_schema.TextLabel + LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) + for l in TextLabel + ] + ) + + +@router.get("/report_labels") +def get_report_lables() -> ValidLabelsResponse: + report_labels = [ + TextLabel.spam, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + TextLabel.moral_judgement, + TextLabel.political_content, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.quality, + ] + return ValidLabelsResponse( + valid_labels=[ + LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) + for l in report_labels ] ) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 9952c654..157845d7 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TextLabel from pydantic import AnyHttpUrl, BaseModel, BaseSettings, FilePath, PostgresDsn, validator @@ -46,13 +46,56 @@ class TreeManagerConfiguration(BaseModel): num_required_rankings: int = 3 """Number of rankings in which the message participated.""" - mandatory_labels_initial_prompt: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + labels_initial_prompt: list[TextLabel] = [ + TextLabel.spam, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.creativity, + TextLabel.humor, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + labels_assistant_reply: list[TextLabel] = [ + TextLabel.spam, + TextLabel.fails_task, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.creativity, + TextLabel.humor, + TextLabel.toxicity, + 
TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + labels_prompter_reply: list[TextLabel] = [ + TextLabel.spam, + TextLabel.quality, + TextLabel.helpfulness, + TextLabel.humor, + TextLabel.creativity, + TextLabel.toxicity, + TextLabel.violence, + TextLabel.not_appropriate, + TextLabel.pii, + TextLabel.hate_speech, + TextLabel.sexual_content, + ] + + mandatory_labels_initial_prompt: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for initial prompts.""" - mandatory_labels_assistant_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_assistant_reply: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for assistant replies.""" - mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] + mandatory_labels_prompter_reply: Optional[list[TextLabel]] = [TextLabel.spam] """Mandatory labels in text-labeling tasks for prompter replies.""" rank_prompter_replies: bool = False diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 5f323d5d..24fafc01 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -64,6 +64,12 @@ class Message(SQLModel, table=True): if not self.payload or not isinstance(self.payload.payload, MessagePayload): raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR) + def has_emoji(self, emoji_code: str) -> bool: + return self.emojis and emoji_code in self.emojis and self.emojis[emoji_code] > 0 + + def has_user_emoji(self, emoji_code: str) -> bool: + return self._user_emojis and emoji_code in self._user_emojis + @property def text(self) -> str: self.ensure_is_message() diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py 
index b31b53d7..d3f655fc 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -461,15 +461,22 @@ class PromptRepository: ) if message_id: - message = self.fetch_message(message_id) - if task: + if not task: + if text_labels.is_report is True: + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag + ) + + # update existing record for repeated updates (same user no task associated) + existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) + if existing_text_label is not None: + existing_text_label.labels = text_labels.labels + model = existing_text_label + + else: + message = self.fetch_message(message_id) message.review_count += 1 self.db.add(message) - # for the same User id with no task id associated with the message, then update existing record for repeated updates - existing_text_label = self.fetch_non_task_text_labels(message_id, self.user_id) - if existing_text_label is not None: - existing_text_label.labels = text_labels.labels - model = existing_text_label self.db.add(model) return model, task, message @@ -936,6 +943,20 @@ WHERE message.id = cc.id; op = protocol_schema.EmojiOp.add if op == protocol_schema.EmojiOp.add: + # hard coded exclusivity of thumbs_up & thumbs_down + if emoji == protocol_schema.EmojiCode.thumbs_up and message.has_user_emoji( + protocol_schema.EmojiCode.thumbs_down.value + ): + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_down + ) + elif emoji == protocol_schema.EmojiCode.thumbs_down and message.has_user_emoji( + protocol_schema.EmojiCode.thumbs_up.value + ): + message = self.handle_message_emoji( + message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up + ) + # insert emoji record & increment count message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji) 
self.db.add(message_emoji) diff --git a/backend/oasst_backend/schemas/text_labels.py b/backend/oasst_backend/schemas/text_labels.py index 9135c558..e846d8f4 100644 --- a/backend/oasst_backend/schemas/text_labels.py +++ b/backend/oasst_backend/schemas/text_labels.py @@ -1,13 +1,6 @@ -from typing import Optional - +from oasst_shared.schemas.protocol import LabelDescription from pydantic import BaseModel -class LabelOption(BaseModel): - name: str - display_text: str - help_text: Optional[str] - - class ValidLabelsResponse(BaseModel): - valid_labels: list[LabelOption] + valid_labels: list[LabelDescription] diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 89a51807..77419184 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -94,8 +94,6 @@ class TreeManagerStats(pydantic.BaseModel): class TreeManager: - _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) - def __init__( self, db: Session, @@ -216,6 +214,15 @@ class TreeManager: incomplete_rankings=incomplete_rankings, ) + @staticmethod + def _get_label_descriptions(valid_labels: list[TextLabels]) -> list[protocol_schema.LabelDescription]: + return [ + protocol_schema.LabelDescription( + name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text + ) + for l in valid_labels + ] + def next_task( self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random, @@ -356,14 +363,14 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.full label_disposition = protocol_schema.LabelTaskDisposition.quality - valid_labels = self._all_text_labels if message.role == "assistant": + valid_labels = self.cfg.labels_assistant_reply if ( desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_assistant ): - valid_labels = list(map(lambda x: x.value, 
self.cfg.mandatory_labels_assistant_reply)) + valid_labels = self.cfg.mandatory_labels_assistant_reply label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam @@ -372,27 +379,30 @@ class TreeManager: message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) else: + valid_labels = self.cfg.labels_prompter_reply if ( desired_task_type == protocol_schema.TaskRequestType.random and random.random() > self.cfg.p_full_labeling_review_reply_prompter ): - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) + valid_labels = self.cfg.mandatory_labels_prompter_reply label_mode = protocol_schema.LabelTaskMode.simple logger.info(f"Generating a LabelPrompterReplyTask. 
({label_mode=:s})") task = protocol_schema.LabelPrompterReplyTask( message_id=message.id, conversation=conversation, reply=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) parent_message_id = message.id @@ -456,10 +466,10 @@ class TreeManager: label_mode = protocol_schema.LabelTaskMode.full label_disposition = protocol_schema.LabelTaskDisposition.quality - valid_labels = self._all_text_labels + valid_labels = self.cfg.labels_initial_prompt if random.random() > self.cfg.p_full_labeling_review_prompt: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)) + valid_labels = self.cfg.mandatory_labels_initial_prompt label_mode = protocol_schema.LabelTaskMode.simple label_disposition = protocol_schema.LabelTaskDisposition.spam @@ -467,10 +477,11 @@ class TreeManager: task = protocol_schema.LabelInitialPromptTask( message_id=message.id, prompt=message.text, - valid_labels=valid_labels, + valid_labels=list(map(lambda x: x.value, valid_labels)), mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)), mode=label_mode, disposition=label_disposition, + labels=self._get_label_descriptions(valid_labels), ) parent_message_id = message.id @@ -577,7 +588,7 @@ class TreeManager: _, task, msg = pr.store_text_labels(interaction) - # if it was a respones for a task, check if we have enough reviews to calc review_result + # if it was a response for a task, check if we have enough reviews to calc review_result if task and msg: reviews = self.query_reviews_for_message(msg.id) acceptance_score = self._calculate_acceptance(reviews) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 31caa340..22a4adfb 100644 --- 
a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -231,12 +231,20 @@ class LabelTaskDisposition(str, enum.Enum): spam = "spam" +class LabelDescription(BaseModel): + name: str + widget: str + display_text: str + help_text: Optional[str] + + class AbstractLabelTask(Task): message_id: UUID valid_labels: list[str] mandatory_labels: Optional[list[str]] mode: Optional[LabelTaskMode] disposition: Optional[LabelTaskDisposition] + labels: Optional[list[LabelDescription]] class LabelInitialPromptTask(AbstractLabelTask): @@ -324,39 +332,48 @@ class MessageRanking(Interaction): ranking: conlist(item_type=int, min_items=1) +class LabelWidget(str, enum.Enum): + yes_no = "yes_no" + flag = "flag" + likert = "likert" + + class TextLabel(str, enum.Enum): """A label for a piece of text.""" - def __new__(cls, label: str, display_text: str = "", help_text: str = None): + def __new__(cls, label: str, widget: LabelWidget, display_text: str = "", help_text: str = None): obj = str.__new__(cls, label) obj._value_ = label + obj.widget = widget obj.display_text = display_text obj.help_text = help_text return obj - spam = "spam", "Seems to be intentionally low-quality or irrelevant" - fails_task = "fails_task", "Fails to follow the correct instruction / task" - not_appropriate = "not_appropriate", "Inappropriate for customer assistant" - violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm" - excessive_harm = ( - "excessive_harm", - "Content likely to cause excessive harm not justifiable in the context", - "Harm refers to physical or mental damage or injury to someone or something. 
Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", - ) - sexual_content = "sexual_content", "Contains sexual content" - toxicity = "toxicity", "Contains rude, abusive, profane or insulting content" - moral_judgement = "moral_judgement", "Expresses moral judgement" - political_content = "political_content", "Expresses political views" - humor = "humor", "Contains humorous content including sarcasm" + # yes/no questions + spam = "spam", LabelWidget.yes_no, "Seems to be intentionally low-quality or irrelevant" + fails_task = "fails_task", LabelWidget.yes_no, "Fails to follow the correct instruction / task" + + # flags + pii = "pii", LabelWidget.flag, "Contains personal identifiable information (PII)" + not_appropriate = "not_appropriate", LabelWidget.flag, "Inappropriate" hate_speech = ( "hate_speech", + LabelWidget.flag, "Content is abusive or threatening and expresses prejudice against a protected characteristic", - "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + "Prejudice refers to preconceived views not based on reason. 
Protected characteristics " + "include gender, ethnicity, religion, sexual orientation, and similar characteristics.", ) - threat = "threat", "Contains a threat against a person or persons" - misleading = "misleading", "Contains text which is incorrect or misleading" - helpful = "helpful", "Completes the task to a high standard" - creative = "creative", "Expresses creativity in responding to the task" + sexual_content = "sexual_content", LabelWidget.flag, "Contains sexual content" + moral_judgement = "moral_judgement", LabelWidget.flag, "Expresses moral judgement" + political_content = "political_content", LabelWidget.flag, "Expresses political views" + + # likert + quality = "quality", LabelWidget.likert, "Overall subjective quality rating of the message" + toxicity = "toxicity", LabelWidget.likert, "Rude, abusive, profane or insulting content" + humor = "humor", LabelWidget.likert, "Humorous content including sarcasm" + helpfulness = "helpfulness", LabelWidget.likert, "Helpfulness of the message" + creativity = "creativity", LabelWidget.likert, "Creativity" + violence = "violence", LabelWidget.likert, "Violence/abuse/terrorism/self-harm" class TextLabels(Interaction): @@ -367,6 +384,7 @@ class TextLabels(Interaction): labels: dict[TextLabel, float] message_id: UUID task_id: Optional[UUID] + is_report: Optional[bool] @property def has_message_id(self) -> bool: From 1011fa527c97e7b0f04aea905a588279177420fa Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Fri, 27 Jan 2023 13:25:07 +1100 Subject: [PATCH 026/101] website: Fix handling of new and unknown labels --- .../src/components/Survey/LabelInputGroup.tsx | 288 ++++++++++-------- 1 file changed, 167 insertions(+), 121 deletions(-) diff --git a/website/src/components/Survey/LabelInputGroup.tsx b/website/src/components/Survey/LabelInputGroup.tsx index 06e0888b..9a06b6ca 100644 --- a/website/src/components/Survey/LabelInputGroup.tsx +++ b/website/src/components/Survey/LabelInputGroup.tsx @@ -18,126 +18,172 @@ 
interface LabelInfo { inverted: boolean; } -// This should be moved to the valid labels api endpoint -const label_messages: { - [label: string]: LabelInfo; -} = { - spam: { - zeroText: "Not Spam", - zeroDescription: ["Suitable for training Open Assistant."], - oneText: "Spam", - oneDescription: [ - "Seems to be intentionally low-quality or irrelevant", - 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', - "This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", - "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", - ], - inverted: true, - }, - fails_task: { - zeroText: "Follows Instructions", - zeroDescription: [], - oneText: "Fails Task", - oneDescription: ["Fails to follow the correct instruction / task"], - inverted: true, - }, - not_appropriate: { - zeroText: "Appropriate", - zeroDescription: [], - oneText: "Inappropriate", - oneDescription: ["Inappropriate for customer assistant"], - inverted: true, - }, - violence: { - zeroText: "Harmless", - zeroDescription: [], - oneText: "Violent", - oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"], - inverted: true, - }, - excessive_harm: { - zeroText: "Safe", - zeroDescription: [], - oneText: "Harmful", - oneDescription: [ - "Content likely to cause excessive harm not justifiable in the context", - "Harm refers to physical or mental damage or injury to someone or something. 
Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", - ], - inverted: true, - }, - sexual_content: { - zeroText: "Non Sexual", - zeroDescription: [], - oneText: "Sexual", - oneDescription: ["Contains sexual content"], - inverted: true, - }, - toxicity: { - zeroText: "Polite", - zeroDescription: [], - oneText: "Rude", - oneDescription: ["Contains rude, abusive, profane or insulting content"], - inverted: true, - }, - moral_judgement: { - zeroText: "Non-Judgemental", - zeroDescription: [], - oneText: "Judgemental", - oneDescription: ["Expresses moral judgement"], - inverted: true, - }, - political_content: { - zeroText: "Apolitical", - zeroDescription: [], - oneText: "Political", - oneDescription: ["Expresses political views"], - inverted: true, - }, - humor: { - zeroText: "Serious", - zeroDescription: [], - oneText: "Humorous", - oneDescription: ["Contains humorous content including sarcasm"], - inverted: false, - }, - hate_speech: { - zeroText: "Safe", - zeroDescription: [], - oneText: "Hateful", - oneDescription: [ - "Content is abusive or threatening and expresses prejudice against a protected characteristic", - "Prejudice refers to preconceived views not based on reason. 
Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", - ], - inverted: true, - }, - threat: { - zeroText: "Safe", - zeroDescription: [], - oneText: "Threatening", - oneDescription: ["Contains a threat against a person or persons"], - inverted: true, - }, - misleading: { - zeroText: "Accurate", - zeroDescription: [], - oneText: "Misleading", - oneDescription: ["Contains text which is incorrect or misleading"], - inverted: true, - }, - helpful: { - zeroText: "Unhelful", - zeroDescription: [], - oneText: "Helpful", - oneDescription: ["Completes the task to a high standard"], - inverted: false, - }, - creative: { - zeroText: "Boring", - zeroDescription: [], - oneText: "Creative", - oneDescription: ["Expresses creativity in responding to the task"], - inverted: false, - }, +const getLabelInfo = (label: string): LabelInfo => { + switch (label) { + case "spam": + return { + zeroText: "Not Spam", + zeroDescription: ["Suitable for training Open Assistant."], + oneText: "Spam", + oneDescription: [ + "Seems to be intentionally low-quality or irrelevant", + 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', + "This is not an assessment of whether this message is the best possible answer. 
Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", + "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", + ], + inverted: true, + }; + case "fails_task": + return { + zeroText: "Follows Instructions", + zeroDescription: [], + oneText: "Fails Task", + oneDescription: ["Fails to follow the correct instruction / task"], + inverted: true, + }; + case "not_appropriate": + return { + zeroText: "Appropriate", + zeroDescription: [], + oneText: "Inappropriate", + oneDescription: ["Inappropriate for customer assistant"], + inverted: true, + }; + case "violence": + return { + zeroText: "Harmless", + zeroDescription: [], + oneText: "Violent", + oneDescription: ["Encourages or fails to discourage violence/abuse/terrorism/self-harm"], + inverted: true, + }; + case "excessive_harm": + return { + zeroText: "Safe", + zeroDescription: [], + oneText: "Harmful", + oneDescription: [ + "Content likely to cause excessive harm not justifiable in the context", + "Harm refers to physical or mental damage or injury to someone or something. 
Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.", + ], + inverted: true, + }; + case "sexual_content": + return { + zeroText: "Non Sexual", + zeroDescription: [], + oneText: "Sexual", + oneDescription: ["Contains sexual content"], + inverted: true, + }; + case "toxicity": + return { + zeroText: "Polite", + zeroDescription: [], + oneText: "Rude", + oneDescription: ["Contains rude, abusive, profane or insulting content"], + inverted: true, + }; + case "moral_judgement": + return { + zeroText: "Non-Judgemental", + zeroDescription: [], + oneText: "Judgemental", + oneDescription: ["Expresses moral judgement"], + inverted: true, + }; + case "political_content": + return { + zeroText: "Apolitical", + zeroDescription: [], + oneText: "Political", + oneDescription: ["Expresses political views"], + inverted: true, + }; + case "humor": + return { + zeroText: "Serious", + zeroDescription: [], + oneText: "Humorous", + oneDescription: ["Contains humorous content including sarcasm"], + inverted: false, + }; + case "hate_speech": + return { + zeroText: "Safe", + zeroDescription: [], + oneText: "Hateful", + oneDescription: [ + "Content is abusive or threatening and expresses prejudice against a protected characteristic", + "Prejudice refers to preconceived views not based on reason. 
Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", + ], + inverted: true, + }; + case "threat": + return { + zeroText: "Safe", + zeroDescription: [], + oneText: "Threatening", + oneDescription: ["Contains a threat against a person or persons"], + inverted: true, + }; + case "misleading": + return { + zeroText: "Accurate", + zeroDescription: [], + oneText: "Misleading", + oneDescription: ["Contains text which is incorrect or misleading"], + inverted: true, + }; + case "helpful": + return { + zeroText: "Unhelful", + zeroDescription: [], + oneText: "Helpful", + oneDescription: ["Completes the task to a high standard"], + inverted: false, + }; + case "creative": + return { + zeroText: "Boring", + zeroDescription: [], + oneText: "Creative", + oneDescription: ["Expresses creativity in responding to the task"], + inverted: false, + }; + case "pii": + return { + zeroText: "Clean", + zeroDescription: [], + oneText: "Contains PII", + oneDescription: ["Contains personally identifing information"], + inverted: false, + }; + case "quality": + return { + zeroText: "Low Quality", + zeroDescription: [], + oneText: "High Quality", + oneDescription: [], + inverted: false, + }; + case "creativity": + return { + zeroText: "Ordinary", + zeroDescription: [], + oneText: "Creative", + oneDescription: [], + inverted: false, + }; + default: + return { + zeroText: `!${label}`, + zeroDescription: [], + oneText: label, + oneDescription: [], + inverted: false, + }; + } }; export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => { @@ -148,7 +194,7 @@ export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: Label return ( {labelIDs.map((labelId, idx) => { - const { zeroText, oneText, zeroDescription, oneDescription, inverted } = label_messages[labelId]; + const { zeroText, oneText, zeroDescription, oneDescription, inverted } = getLabelInfo(labelId); let textA = zeroText; let 
textB = oneText; From 24f4c0879626c45b9c94c034eca9b06be829c8c0 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Thu, 26 Jan 2023 13:13:46 +0000 Subject: [PATCH 027/101] Refactor task page routes and repetition Remove blank line Lint add blank line Link pre-commit --- .../e2e/tasks/no_tasks_available.cy.ts | 23 ++++++++++ website/public/locales/en/common.json | 5 ++- website/src/components/EmptyState.tsx | 2 +- website/src/components/TaskPage/TaskPage.tsx | 40 ++++++++++++++++++ website/src/lib/api.ts | 3 +- website/src/lib/constants.ts | 42 +++++++++++++++++++ website/src/pages/create/assistant_reply.tsx | 29 ++----------- website/src/pages/create/initial_prompt.tsx | 29 ++----------- website/src/pages/create/user_reply.tsx | 33 +++------------ .../pages/evaluate/rank_assistant_replies.tsx | 29 ++----------- .../pages/evaluate/rank_initial_prompts.tsx | 29 ++----------- .../src/pages/evaluate/rank_user_replies.tsx | 33 +++------------ .../src/pages/label/label_assistant_reply.tsx | 29 ++----------- .../src/pages/label/label_initial_prompt.tsx | 29 ++----------- .../src/pages/label/label_prompter_reply.tsx | 29 ++----------- website/src/pages/tasks/random.tsx | 32 ++------------ website/src/types/Task.ts | 2 +- 17 files changed, 147 insertions(+), 271 deletions(-) create mode 100644 website/cypress/e2e/tasks/no_tasks_available.cy.ts create mode 100644 website/src/components/TaskPage/TaskPage.tsx create mode 100644 website/src/lib/constants.ts diff --git a/website/cypress/e2e/tasks/no_tasks_available.cy.ts b/website/cypress/e2e/tasks/no_tasks_available.cy.ts new file mode 100644 index 00000000..27c33b48 --- /dev/null +++ b/website/cypress/e2e/tasks/no_tasks_available.cy.ts @@ -0,0 +1,23 @@ +describe("no tasks available", () => { + it("displays an empty state when no tasks are available", () => { + cy.signInWithEmail("cypress@example.com"); + cy.intercept( + { + method: "GET", + url: "/api/new_task/prompter_reply", + }, + { + statusCode: 500, + body: { + message: "No 
tasks of type 'label_prompter_reply' are currently available.", + errorCode: 1006, + httpStatusCode: 503, + }, + } + ).as("newTaskPrompterReply"); + cy.visit("/create/user_reply"); + cy.wait("@newTaskPrompterReply").then(() => { + cy.get('[data-cy="cy-no-tasks"]').should("exist"); + }); + }); +}); diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index 8f35eaab..0b0f9d37 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -9,11 +9,12 @@ "docs": "Docs", "github": "GitHub", "legal": "Legal", + "loading": "Loading...", + "more_information": "More Information", "privacy_policy": "Privacy Policy", "report_a_bug": "Report a Bug", "sign_in": "Sign In", "sign_out": "Sign Out", "terms_of_service": "Terms of Service", - "title": "Open Assistant", - "more_information": "More Information" + "title": "Open Assistant" } diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx index b0455774..63d8d3bf 100644 --- a/website/src/components/EmptyState.tsx +++ b/website/src/components/EmptyState.tsx @@ -15,7 +15,7 @@ export const EmptyState = (props: EmptyStateProps) => { - {props.text} + {props.text} Go back to the dashboard diff --git a/website/src/components/TaskPage/TaskPage.tsx b/website/src/components/TaskPage/TaskPage.tsx new file mode 100644 index 00000000..9fc26c42 --- /dev/null +++ b/website/src/components/TaskPage/TaskPage.tsx @@ -0,0 +1,40 @@ +import Head from "next/head"; +import { useTranslation } from "next-i18next"; +import { TaskEmptyState } from "src/components/EmptyState"; +import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { Task } from "src/components/Tasks/Task"; +import { TaskInfos } from "src/components/Tasks/TaskTypes"; +import { apiHooksByType, ERROR_CODES } from "src/lib/constants"; +import { getTypeSafei18nKey } from "src/lib/i18n"; +import { TaskType } from "src/types/Task"; + +type TaskPageProps = { + type: 
TaskType; +}; + +export const TaskPage = ({ type }: TaskPageProps) => { + const { t } = useTranslation(["tasks", "common"]); + const apiHook = apiHooksByType[type]; + const { tasks, isLoading, reset, trigger, error } = apiHook(type); + const taskType = TaskInfos.find((taskType) => taskType.type === type); + + if (isLoading) { + return ; + } + + if (tasks.length === 0 || error?.errorCode === ERROR_CODES.TASK_REQUESTED_TYPE_NOT_AVAILABLE) { + return ; + } + + const task = tasks[0]; + + return ( + <> + + {t(getTypeSafei18nKey(`${taskType.id}.label`))} + + + + + ); +}; diff --git a/website/src/lib/api.ts b/website/src/lib/api.ts index d61016d2..2649daf8 100644 --- a/website/src/lib/api.ts +++ b/website/src/lib/api.ts @@ -17,7 +17,8 @@ export const post = (url: string, { arg: data }) => api.post(url, data).then((re api.interceptors.response.use( (response) => response, (error) => { - throw new OasstError(error.message ?? error, error.error_code, error?.response?.status || -1); + const err = error?.response?.data; + throw new OasstError(err?.message ?? 
error, err?.errorCode, error?.response?.httpStatusCode || -1); } ); diff --git a/website/src/lib/constants.ts b/website/src/lib/constants.ts new file mode 100644 index 00000000..6269dbf9 --- /dev/null +++ b/website/src/lib/constants.ts @@ -0,0 +1,42 @@ +import { + useCreateAssistantReply, + useCreateInitialPrompt, + useCreatePrompterReply, +} from "src/hooks/tasks/useCreateReply"; +import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI"; +import { + useLabelAssistantReplyTask, + useLabelInitialPromptTask, + useLabelPrompterReplyTask, +} from "src/hooks/tasks/useLabelingTask"; +import { + useRankAssistantRepliesTask, + useRankInitialPromptsTask, + useRankPrompterRepliesTask, +} from "src/hooks/tasks/useRankReplies"; +import { TaskType } from "src/types/Task"; + +export const ERROR_CODES = { + TASK_REQUESTED_TYPE_NOT_AVAILABLE: 1006, + TASK_INVALID_REQUEST_TYPE: 1000, + TASK_ACK_FAILED: 1001, + TASK_NACK_FAILED: 1002, + TASK_INVALID_RESPONSE_TYPE: 1003, + TASK_INTERACTION_REQUEST_FAILED: 1004, + TASK_GENERATION_FAILED: 1005, + TASK_AVAILABILITY_QUERY_FAILED: 1007, + TASK_MESSAGE_TOO_LONG: 1008, +}; + +export const apiHooksByType = { + [TaskType.random]: useGenericTaskAPI, + [TaskType.assistant_reply]: useCreateAssistantReply, + [TaskType.initial_prompt]: useCreateInitialPrompt, + [TaskType.label_assistant_reply]: useLabelAssistantReplyTask, + [TaskType.label_initial_prompt]: useLabelInitialPromptTask, + [TaskType.label_prompter_reply]: useLabelPrompterReplyTask, + [TaskType.prompter_reply]: useCreatePrompterReply, + [TaskType.rank_assistant_replies]: useRankAssistantRepliesTask, + [TaskType.rank_initial_prompts]: useRankInitialPromptsTask, + [TaskType.rank_prompter_replies]: useRankPrompterRepliesTask, +}; diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index 1c83eb23..0f3095a9 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -1,32 
+1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const AssistantReply = () => { - const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Reply as Assistant - - - - - ); -}; +const AssistantReply = () => ; AssistantReply.getLayout = getDashboardLayout; diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 639df68f..c73f2e5d 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const InitialPrompt = () => { - const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Initial Prompt - - - - - ); -}; +const InitialPrompt = () => ; InitialPrompt.getLayout = getDashboardLayout; diff --git 
a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 5898439c..39218476 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -1,33 +1,10 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const UserReply = () => { - const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); +const PrompterReply = () => ; - if (isLoading) { - return ; - } +PrompterReply.getLayout = getDashboardLayout; - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Reply as User - - - - - ); -}; - -UserReply.getLayout = getDashboardLayout; - -export default UserReply; +export default PrompterReply; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index da79d92f..dd4c1df9 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from 
"src/lib/default_static_props"; -const RankAssistantReplies = () => { - const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Rank Assistant Replies - - - - - ); -}; +const RankAssistantReplies = () => ; RankAssistantReplies.getLayout = getDashboardLayout; diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index f23fc0ed..1eb91289 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const RankInitialPrompts = () => { - const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Rank Initial Prompts - - - - - ); -}; +const RankInitialPrompts = () => ; RankInitialPrompts.getLayout = getDashboardLayout; diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx index cee82b87..a1caba59 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -1,33 +1,10 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen 
} from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const RankUserReplies = () => { - const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); +const RankPrompterReplies = () => ; - if (isLoading) { - return ; - } +RankPrompterReplies.getLayout = getDashboardLayout; - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Rank User Replies - - - - - ); -}; - -RankUserReplies.getLayout = getDashboardLayout; - -export default RankUserReplies; +export default RankPrompterReplies; diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 07a6cb1c..8be12b41 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const LabelAssistantReply = () => { - const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Label Assistant Reply - - - - - ); -}; +const LabelAssistantReply = () => ; LabelAssistantReply.getLayout = getDashboardLayout; diff 
--git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 8735044f..c5fed344 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const LabelInitialPrompt = () => { - const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Label Initial Prompt - - - - - ); -}; +const LabelInitialPrompt = () => ; LabelInitialPrompt.getLayout = getDashboardLayout; diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index 17164e11..33e8aba4 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -1,32 +1,9 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; +import { TaskType } from "src/types/Task"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; -const LabelPrompterReply = 
() => { - const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask(); - - if (isLoading) { - return ; - } - - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Label Prompter Reply - - - - - ); -}; +const LabelPrompterReply = () => ; LabelPrompterReply.getLayout = getDashboardLayout; diff --git a/website/src/pages/tasks/random.tsx b/website/src/pages/tasks/random.tsx index f1c04d2c..cd7ed458 100644 --- a/website/src/pages/tasks/random.tsx +++ b/website/src/pages/tasks/random.tsx @@ -1,34 +1,10 @@ -import Head from "next/head"; -import { TaskEmptyState } from "src/components/EmptyState"; import { getDashboardLayout } from "src/components/Layout"; -import { LoadingScreen } from "src/components/Loading/LoadingScreen"; -import { Task } from "src/components/Tasks/Task"; -import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI"; +import { TaskPage } from "src/components/TaskPage/TaskPage"; export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; import { TaskType } from "src/types/Task"; -const RandomTask = () => { - const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random); +const Random = () => ; - if (isLoading) { - return ; - } +Random.getLayout = getDashboardLayout; - if (tasks.length === 0) { - return ; - } - - return ( - <> - - Random Task - - - - - ); -}; - -RandomTask.getLayout = (page) => getDashboardLayout(page); - -export default RandomTask; +export default Random; diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index 12e37db0..7ae48138 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -1,4 +1,4 @@ -export const enum TaskType { +export enum TaskType { initial_prompt = "initial_prompt", assistant_reply = "assistant_reply", prompter_reply = "prompter_reply", From e6009933db67de3d40cdbac87104858471dd7023 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Fri, 27 Jan 2023 10:16:25 +0000 Subject: [PATCH 028/101] Rename const to 
taskInfo --- website/src/components/TaskPage/TaskPage.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/src/components/TaskPage/TaskPage.tsx b/website/src/components/TaskPage/TaskPage.tsx index 9fc26c42..41d89405 100644 --- a/website/src/components/TaskPage/TaskPage.tsx +++ b/website/src/components/TaskPage/TaskPage.tsx @@ -16,7 +16,7 @@ export const TaskPage = ({ type }: TaskPageProps) => { const { t } = useTranslation(["tasks", "common"]); const apiHook = apiHooksByType[type]; const { tasks, isLoading, reset, trigger, error } = apiHook(type); - const taskType = TaskInfos.find((taskType) => taskType.type === type); + const taskInfo = TaskInfos.find((taskType) => taskType.type === type); if (isLoading) { return ; @@ -31,8 +31,8 @@ export const TaskPage = ({ type }: TaskPageProps) => { return ( <> - {t(getTypeSafei18nKey(`${taskType.id}.label`))} - + {t(getTypeSafei18nKey(`${taskInfo.id}.label`))} + From a0f4449e9f9ef1919859a60c30c459eb4a9ba968 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 27 Jan 2023 15:23:19 +0100 Subject: [PATCH 029/101] added seed to parameters --- inference/worker/__main__.py | 20 +------------------ .../oasst_shared/schemas/inference.py | 2 +- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index e5c15fb4..190ed788 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -41,25 +41,6 @@ def main( prompt = prefix + "\n".join(messages) + "\nAssistant:" - # TODO: use the seed - # torch.manual_seed(work_request.seed) - # model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ - # 0 - # ]["generated_text"] - # model_output = model_output.strip() - - # # fake streaming - # split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] - # pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] - # for piece in 
pieces: - # if not piece: - # continue - # if piece.strip() in ("User:", "Assistant:"): - # break - # ws.send(inference.WorkResponsePacket(token=piece).json()) - # time.sleep(0.1) - # ws.send(inference.WorkResponsePacket(is_end=True).json()) - response = requests.post( f"{inference_server_url}/generate_stream", json={ @@ -70,6 +51,7 @@ def main( "top_k": work_request.top_k, "top_p": work_request.top_p, "temperature": work_request.temperature, + "seed": work_request.seed, }, }, stream=True, diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index b50cef9c..91a16b61 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -13,7 +13,7 @@ class WorkRequest(pydantic.BaseModel): conversation: protocol.Conversation = pydantic.Field(..., repr=False) model_name: str = "distilgpt2" max_new_tokens: int = 100 - seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1)) do_sample: bool = True top_k: int = 50 top_p: float = 0.9 From ae5d16f3942d5f365dc231bd027295555e4fd392 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 27 Jan 2023 15:40:21 +0100 Subject: [PATCH 030/101] added a tmux inference dev setup script --- inference/full-dev-setup.sh | 19 +++++++++++++++++++ inference/worker/__main__.py | 2 ++ 2 files changed, 21 insertions(+) create mode 100755 inference/full-dev-setup.sh diff --git a/inference/full-dev-setup.sh b/inference/full-dev-setup.sh new file mode 100755 index 00000000..98a5b173 --- /dev/null +++ b/inference/full-dev-setup.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Creates a tmux window with splits for the individual services + +tmux new-session -d -s "inference-dev-setup" +tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m +tmux split-window -h +tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 
ykilcher/text-generation-inference" C-m +tmux split-window -h +tmux send-keys "cd server" C-m +tmux send-keys "uvicorn main:app --reload" C-m +tmux split-window -h +tmux send-keys "cd worker" C-m +tmux send-keys "python __main__.py" C-m +tmux split-window -h +tmux send-keys "cd text-client" C-m +tmux send-keys "sleep 5" C-m +tmux send-keys "python __main__.py" C-m +tmux attach-session -t "inference-dev-setup" diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index 190ed788..96fe164a 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -18,8 +18,10 @@ def main( inference_server_url: str = "http://localhost:8001", ): def on_open(ws: websocket.WebSocket): + logger.info("Connected to backend, sending config...") worker_config = inference.WorkerConfig(model_name=model_name) ws.send(worker_config.json()) + logger.info("Config sent, waiting for work...") def on_message(ws: websocket.WebSocket, message: str): # TODO: what if this comes in, but one is already in progress? From 45a4b09eae04d2a6d3f099da155d576d7efb20f9 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 27 Jan 2023 15:42:01 +0100 Subject: [PATCH 031/101] added setup instructions to readme --- inference/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/inference/README.md b/inference/README.md index bd0272ad..6e1da2c7 100644 --- a/inference/README.md +++ b/inference/README.md @@ -2,7 +2,13 @@ Preliminary implementation of the inference engine for OpenAssistant. -## Development (you'll need multiple terminals) +## Development Variant 1 (you'll need tmux) + +Run `./full-dev-setup.sh` to start the full development setup. Make sure to wait +until the 2nd terminal is ready and says `{"message":"Connected"}` before +entering input into the last terminal. 
+ +## Development Variant 2 (you'll need multiple terminals) Run a redis container (or use the one of the general docker compose file): From 3b04080d7be9f07362a015b5e6b27ff463705bf7 Mon Sep 17 00:00:00 2001 From: James Melvin Ebenezer Date: Fri, 27 Jan 2023 22:36:25 +0530 Subject: [PATCH 032/101] 949_transaction error handling (#950) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: transaction error handling * refactor: retry handling for all decorators as per review comments * fix: raising retry exhausted error * fix: avoid auto refresh on RollBack and review comments * removed refresh_result param from managed_tx_function --------- Co-authored-by: James Melvin Co-authored-by: Andreas Köpf --- backend/oasst_backend/utils/database_utils.py | 168 ++++++++++++------ 1 file changed, 114 insertions(+), 54 deletions(-) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index fb8bf6c5..5f7a3136 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -7,9 +7,14 @@ from loguru import logger from oasst_backend.config import settings from oasst_backend.database import engine from oasst_shared.exceptions import OasstError, OasstErrorCode -from sqlalchemy.exc import OperationalError +from psycopg2.errors import DeadlockDetected, ExclusionViolation, SerializationFailure, UniqueViolation +from sqlalchemy.exc import OperationalError, PendingRollbackError from sqlmodel import Session, SQLModel +""" +Error Handling Reference: https://www.postgresql.org/docs/15/mvcc-serialization-failure-handling.html +""" + class CommitMode(IntEnum): """ @@ -34,28 +39,46 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s @wraps(f) def wrapped_f(self, *args, **kwargs): try: - for i in range(num_retries): - try: - result = f(self, *args, **kwargs) - if auto_commit == CommitMode.COMMIT: + result = None + if 
auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + try: + result = f(self, *args, **kwargs) self.db.commit() - elif auto_commit == CommitMode.FLUSH: - self.db.flush() - elif auto_commit == CommitMode.ROLLBACK: + if isinstance(result, SQLModel): + self.db.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) self.db.rollback() + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation) + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + self.db.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + result = f(self, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + self.db.flush() if isinstance(result, SQLModel): self.db.refresh(result) - return result - except OperationalError: - logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + elif auto_commit == CommitMode.ROLLBACK: self.db.rollback() - raise OasstError( - "DATABASE_MAX_RETIRES_EXHAUSTED", - error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, - http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, - ) + return result except Exception as e: - logger.error("DB Rollback Failure") + logger.info(str(e)) raise e return wrapped_f @@ -70,28 +93,46 @@ def async_managed_tx_method( @wraps(f) async def wrapped_f(self, *args, **kwargs): try: - for i in range(num_retries): - try: - result = await f(self, *args, **kwargs) - if auto_commit == CommitMode.COMMIT: + result = None + if auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + try: + result = f(self, *args, **kwargs) self.db.commit() - elif 
auto_commit == CommitMode.FLUSH: - self.db.flush() - elif auto_commit == CommitMode.ROLLBACK: + if isinstance(result, SQLModel): + self.db.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) self.db.rollback() + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation) + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + self.db.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + result = f(self, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + self.db.flush() if isinstance(result, SQLModel): self.db.refresh(result) - return result - except OperationalError: - logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + elif auto_commit == CommitMode.ROLLBACK: self.db.rollback() - raise OasstError( - "DATABASE_MAX_RETIRES_EXHAUSTED", - error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, - http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, - ) + return result except Exception as e: - logger.exception("DB Rollback Failure") + logger.info(str(e)) raise e return wrapped_f @@ -107,7 +148,6 @@ def managed_tx_function( auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT, session_factory: Callable[..., Session] = default_session_factor, - refresh_result: bool = True, ): """Passes Session object as first argument to wrapped function.""" @@ -115,29 +155,49 @@ def managed_tx_function( @wraps(f) def wrapped_f(*args, **kwargs): try: - for i in range(num_retries): - with session_factory() as session: - try: - result = f(session, *args, **kwargs) - if auto_commit == CommitMode.COMMIT: + result = 
None + if auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + with session_factory() as session: + try: + result = f(session, *args, **kwargs) session.commit() - elif auto_commit == CommitMode.FLUSH: - session.flush() - elif auto_commit == CommitMode.ROLLBACK: + if isinstance(result, SQLModel): + session.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) session.rollback() - if refresh_result and isinstance(result, SQLModel): - session.refresh(result) - return result - except OperationalError: - logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") - session.rollback() - raise OasstError( - "DATABASE_MAX_RETIRES_EXHAUSTED", - error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, - http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, - ) + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, + (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation), + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + session.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + with session_factory() as session: + result = f(session, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + session.flush() + if isinstance(result, SQLModel): + session.refresh(result) + elif auto_commit == CommitMode.ROLLBACK: + session.rollback() + return result except Exception as e: - logger.error("DB Rollback Failure") + logger.info(str(e)) raise e return wrapped_f From ce3b3c7eccd8aadf288b6c67c685c5ec805a97e9 Mon Sep 17 00:00:00 2001 From: notmd Date: Sat, 28 Jan 2023 00:17:31 +0700 Subject: [PATCH 033/101] pass correct param when fetch leaderboard --- 
website/src/lib/oasst_api_client.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index bd263400..3cc4f4ef 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -144,7 +144,7 @@ export class OasstApiClient { time_frame: LeaderboardTimeFrame, { limit = 20 }: { limit?: number } ): Promise { - return this.get(`/api/v1/leaderboards/${time_frame}`, { limit }); + return this.get(`/api/v1/leaderboards/${time_frame}`, { max_count: limit }); } /** From d16598725664241c554c466ebf30a9b2900af4f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 19:43:08 +0100 Subject: [PATCH 034/101] add missing await to async_managed_tx_method --- backend/oasst_backend/utils/database_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 5f7a3136..196178e5 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -98,7 +98,7 @@ def async_managed_tx_method( retry_exhausted = True for i in range(num_retries): try: - result = f(self, *args, **kwargs) + result = await f(self, *args, **kwargs) self.db.commit() if isinstance(result, SQLModel): self.db.refresh(result) From c7692b9049a25ff00c6e39d2854a125cbbc5411d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 27 Jan 2023 19:44:48 +0100 Subject: [PATCH 035/101] add 2nd missing await to async_managed_tx_method --- backend/oasst_backend/utils/database_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 196178e5..5ac25f50 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -123,7 +123,7 @@ def 
async_managed_tx_method( http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, ) else: - result = f(self, *args, **kwargs) + result = await f(self, *args, **kwargs) if auto_commit == CommitMode.FLUSH: self.db.flush() if isinstance(result, SQLModel): From bf77d4dc60d1a295c02242feec3e8c9ce1e08835 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Fri, 27 Jan 2023 10:51:32 +0000 Subject: [PATCH 036/101] Add types for TaskType to TaskHook Pre-commit Apply better naming to task api hooks Lint --- website/src/components/TaskPage/TaskPage.tsx | 6 ++--- website/src/lib/constants.ts | 3 ++- website/src/types/Hooks.ts | 26 ++++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 website/src/types/Hooks.ts diff --git a/website/src/components/TaskPage/TaskPage.tsx b/website/src/components/TaskPage/TaskPage.tsx index 41d89405..e1ecc79c 100644 --- a/website/src/components/TaskPage/TaskPage.tsx +++ b/website/src/components/TaskPage/TaskPage.tsx @@ -4,7 +4,7 @@ import { TaskEmptyState } from "src/components/EmptyState"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { TaskInfos } from "src/components/Tasks/TaskTypes"; -import { apiHooksByType, ERROR_CODES } from "src/lib/constants"; +import { ERROR_CODES, taskApiHooks } from "src/lib/constants"; import { getTypeSafei18nKey } from "src/lib/i18n"; import { TaskType } from "src/types/Task"; @@ -14,8 +14,8 @@ type TaskPageProps = { export const TaskPage = ({ type }: TaskPageProps) => { const { t } = useTranslation(["tasks", "common"]); - const apiHook = apiHooksByType[type]; - const { tasks, isLoading, reset, trigger, error } = apiHook(type); + const taskApiHook = taskApiHooks[type]; + const { tasks, isLoading, reset, trigger, error } = taskApiHook(type); const taskInfo = TaskInfos.find((taskType) => taskType.type === type); if (isLoading) { diff --git a/website/src/lib/constants.ts b/website/src/lib/constants.ts index 
6269dbf9..a260fa28 100644 --- a/website/src/lib/constants.ts +++ b/website/src/lib/constants.ts @@ -14,6 +14,7 @@ import { useRankInitialPromptsTask, useRankPrompterRepliesTask, } from "src/hooks/tasks/useRankReplies"; +import { TaskApiHooks } from "src/types/Hooks"; import { TaskType } from "src/types/Task"; export const ERROR_CODES = { @@ -28,7 +29,7 @@ export const ERROR_CODES = { TASK_MESSAGE_TOO_LONG: 1008, }; -export const apiHooksByType = { +export const taskApiHooks: TaskApiHooks = { [TaskType.random]: useGenericTaskAPI, [TaskType.assistant_reply]: useCreateAssistantReply, [TaskType.initial_prompt]: useCreateInitialPrompt, diff --git a/website/src/types/Hooks.ts b/website/src/types/Hooks.ts new file mode 100644 index 00000000..8fd9aa4f --- /dev/null +++ b/website/src/types/Hooks.ts @@ -0,0 +1,26 @@ +import { MutatorCallback, MutatorOptions } from "swr"; + +import { BaseTask, TaskResponse, TaskType } from "./Task"; + +type ConcreteTaskResponse = TaskResponse; +type TaskError = { errorCode: number; message: string }; + +type Trigger = ( + extraArgument?: unknown, + options?: MutatorOptions +) => Promise; + +type Reset = ( + data?: ConcreteTaskResponse | Promise | MutatorCallback, + opts?: boolean | MutatorOptions +) => Promise; + +type TaskAPIHook = { + tasks: TaskResponse[]; + isLoading: boolean; + error: TaskError; + trigger: Trigger; + reset: Reset; +}; + +export type TaskApiHooks = Record TaskAPIHook>; From 356fd775e93505b32535ca628fa142f987ab0415 Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Sat, 28 Jan 2023 06:52:40 +1100 Subject: [PATCH 037/101] Add emoji reactions and reporting for messages (website) (#952) * website: Move labelling to message ... menu and add reporting and emoji reactions We can add more emoji easily in future, we just need to pick ones that we have consistent icons for. Also added "open in new tab" option so that messages can be navigated to from tasks on mobile. * website: Make new label and report strings translatable. 
* website: Move report api call to oasst client * small fixes * pre-commit --------- Co-authored-by: AbdBarho --- website/public/locales/en/message.json | 11 ++ website/src/components/Messages.tsx | 2 +- .../src/components/Messages/LabelPopup.tsx | 76 ++++++++ .../Messages/MessageEmojiButton.stories.tsx | 34 ++++ .../Messages/MessageEmojiButton.tsx | 48 +++++ .../Messages/MessageTable.stories.tsx | 23 ++- .../src/components/Messages/MessageTable.tsx | 6 +- .../Messages/MessageTableEntry.stories.tsx | 25 ++- .../components/Messages/MessageTableEntry.tsx | 176 +++++++++++++++--- .../Messages/MessageWithChildren.tsx | 6 +- .../src/components/Messages/ReportPopup.tsx | 56 ++++++ website/src/components/Tasks/EvaluateTask.tsx | 1 - .../components/Tasks/LabelTask/LabelTask.tsx | 2 +- website/src/lib/oasst_api_client.ts | 34 +++- website/src/pages/api/messages/[id]/emoji.ts | 30 +++ website/src/pages/api/messages/[id]/index.ts | 11 +- website/src/pages/api/report.ts | 24 +++ website/src/pages/api/set_label.ts | 6 +- website/src/pages/messages/[id]/index.tsx | 2 +- website/src/types/Conversation.ts | 16 +- website/styles/Theme/colors.tsx | 2 + website/types/i18next.d.ts | 4 +- 22 files changed, 541 insertions(+), 54 deletions(-) create mode 100644 website/public/locales/en/message.json create mode 100644 website/src/components/Messages/LabelPopup.tsx create mode 100644 website/src/components/Messages/MessageEmojiButton.stories.tsx create mode 100644 website/src/components/Messages/MessageEmojiButton.tsx create mode 100644 website/src/components/Messages/ReportPopup.tsx create mode 100644 website/src/pages/api/messages/[id]/emoji.ts create mode 100644 website/src/pages/api/report.ts diff --git a/website/public/locales/en/message.json b/website/public/locales/en/message.json new file mode 100644 index 00000000..45ea04a1 --- /dev/null +++ b/website/public/locales/en/message.json @@ -0,0 +1,11 @@ +{ + "reactions": "Reactions", + "label_action": "Label", + "label_title": 
"Label", + "submit_labels": "Submit", + "open_new_tab_action": "Open in new tab", + "report_title": "Report", + "report_action": "Report", + "report_placeholder": "Why should this message be reviewed?", + "send_report": "Send" +} diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index 0934c5af..c9d77e3c 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -20,7 +20,7 @@ export const Messages = ({ messages }: MessagesProps) => { return {items}; }; -export const MessageView = forwardRef((message: Message, ref) => { +export const MessageView = forwardRef, "div">((message: Partial, ref) => { const { colorMode } = useColorMode(); const bgColor = useMemo(() => { diff --git a/website/src/components/Messages/LabelPopup.tsx b/website/src/components/Messages/LabelPopup.tsx new file mode 100644 index 00000000..b2b95278 --- /dev/null +++ b/website/src/components/Messages/LabelPopup.tsx @@ -0,0 +1,76 @@ +import { + Button, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, +} from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { useState } from "react"; +import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; +import { get, post } from "src/lib/api"; +import useSWRImmutable from "swr/immutable"; +import useSWRMutation from "swr/mutation"; + +interface LabelMessagePopupProps { + messageId: string; + show: boolean; + onClose: () => void; +} + +interface Label { + name: string; + display_text: string; + help_text: string; +} + +interface ValidLabelsResponse { + valid_labels: Label[]; +} + +export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => { + const { t } = useTranslation("message"); + const { data: response } = useSWRImmutable("/api/valid_labels", get); + const valid_labels = response?.valid_labels ?? 
[]; + const [values, setValues] = useState(null); + + const { trigger: setLabels } = useSWRMutation("/api/set_label", post); + + const submit = () => { + const label_map: Map = new Map(); + console.assert(valid_labels.length === values.length); + values.forEach((value, idx) => { + if (value !== null) { + label_map.set(valid_labels[idx].name, value); + } + }); + setLabels({ + message_id: messageId, + label_map: Object.fromEntries(label_map), + }); + + setValues(null); + onClose(); + }; + + return ( + + + + {t("label_title")} + + + name)} onChange={setValues} /> + + + + + + + ); +}; diff --git a/website/src/components/Messages/MessageEmojiButton.stories.tsx b/website/src/components/Messages/MessageEmojiButton.stories.tsx new file mode 100644 index 00000000..d74836d5 --- /dev/null +++ b/website/src/components/Messages/MessageEmojiButton.stories.tsx @@ -0,0 +1,34 @@ +import React from "react"; + +import { MessageEmojiButton } from "./MessageEmojiButton"; + +// eslint-disable-next-line import/no-anonymous-default-export +export default { + title: "Messages/MessageEmojiButton", + component: MessageEmojiButton, +}; + +const Template = ({ emoji, count, checked }: { emoji: string; count: number; checked?: boolean }) => { + return ; +}; + +export const Default = Template.bind({}); +Default.args = { + emoji: "+1", + count: 7, + checked: false, +}; + +export const BigNumber = Template.bind({}); +BigNumber.args = { + emoji: "+1", + count: 999, + checked: false, +}; + +export const Checked = Template.bind({}); +Checked.args = { + emoji: "+1", + count: 2, + checked: true, +}; diff --git a/website/src/components/Messages/MessageEmojiButton.tsx b/website/src/components/Messages/MessageEmojiButton.tsx new file mode 100644 index 00000000..8b5c9ff7 --- /dev/null +++ b/website/src/components/Messages/MessageEmojiButton.tsx @@ -0,0 +1,48 @@ +import { Button } from "@chakra-ui/react"; +import { BoxSelect, Flag, LucideProps, ThumbsDown, ThumbsUp } from "lucide-react"; +import { 
ReactElement } from "react"; +import { MessageEmoji } from "src/types/Conversation"; + +type EmojiIconPurpose = "MINI_BUTTON" | "NORMAL"; + +const defaultIconProps: (purpose: EmojiIconPurpose) => LucideProps = (purpose: EmojiIconPurpose) => { + if (purpose === "MINI_BUTTON") return { height: "1em" }; + return {}; +}; + +export const getEmojiIcon = (name: string, purpose: EmojiIconPurpose): ReactElement => { + switch (name) { + case "+1": + return ; + case "-1": + return ; + case "flag": + case "red_flag": + return ; + default: + return ; + } +}; + +interface MessageEmojiButtonProps { + emoji: MessageEmoji; + checked?: boolean; + onClick: () => void; +} + +export const MessageEmojiButton = ({ emoji, checked, onClick }: MessageEmojiButtonProps) => { + return ( + + ); +}; diff --git a/website/src/components/Messages/MessageTable.stories.tsx b/website/src/components/Messages/MessageTable.stories.tsx index 6f383b01..bc03aed1 100644 --- a/website/src/components/Messages/MessageTable.stories.tsx +++ b/website/src/components/Messages/MessageTable.stories.tsx @@ -29,18 +29,24 @@ Default.args = { is_assistant: true, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, { text: "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?", is_assistant: false, id: "", frontend_message_id: "", + emojis: { "-1": 11, red_flag: 2 }, + user_emojis: [], }, { text: "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, ], enableLink: true, @@ -50,12 +56,21 @@ Default.args = { export const Conversation = Template.bind({}); Conversation.args = { messages: [ - { text: "Hello! How can I help you?", is_assistant: true, id: "", frontend_message_id: "" }, + { + text: "Hello! 
How can I help you?", + is_assistant: true, + id: "", + frontend_message_id: "", + emojis: {}, + user_emojis: [], + }, { text: "Who were the 8 presidents before George Washington?", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, ], enableLink: false, @@ -70,18 +85,24 @@ LongText.args = { is_assistant: true, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, { text: "Yes, I think they can be helpful when the child misbehaves, but they should be used with a little bit of compassion and understanding that it\u2019s not the natural state of things to have an adult yelling at them. Time outs are also often used without letting the child know how they\u2019re getting out of the time out, which can make it feel arbitrary or like a punishment, rather than a consequence for something they did. It\u2019s really easy for adults to do this kind of thing unconsciously. It\u2019s easy to get caught up in the notion that \u201cThey\u2019re in time out, and that\u2019s the end of it!\u201d but kids can be pretty imaginative, and they can use their own creativity to make their way out of time outs. A compassionate time out ends when the child shows a sign of understanding what they\u2019ve done wrong, and are ready to begin again. That way the child knows they\u2019re learning, and that the parent is seeing them as an intelligent person, even if they sometimes mess up. You can still use the other techniques you were using to be tough when necessary, but using a compassionate approach will let you use them without actually using them!", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, { text: "No. The USA was founded by a Puritan group of Protestants, but it didn\u2019t adopt the religion of the Puritans until much later, and it was always a secular state. The Puritans observed the Sabbath on Sunday, and the Puritans only had a small influence in the early history of the USA. 
It\u2019s difficult to trace the origins of closing stores on Sunday, but one early and short-lived attempt at forcing the Sabbath on people in the 1800s was motivated by the Protestant ideal that people should spend Sunday focusing on spiritual activities. By the mid-1800s, when the Sunday closing law was made, there was not a lot of pressure from that standpoint, but the church had begun to advocate for Sunday closing laws as a way of counteracting the negative effects of industrialization on the day of rest. Even after that shift, closing stores on Sunday was not always possible, since the religious Sunday was not always chosen for observance. And as industrialization accelerated and mechanization made it possible to operate stores on Sunday, the law was not enforced as much as people liked. The day of rest was also being violated by stores that stayed open all day on Sunday, so closing stores on Sundays became an effort to protect the Sabbath for all citizens.", is_assistant: false, id: "", frontend_message_id: "", + emojis: {}, + user_emojis: [], }, ], enableLink: true, diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index acf92e05..2d39f346 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -11,11 +11,11 @@ interface MessageTableProps { export function MessageTable({ messages, enableLink, highlightLastMessage }: MessageTableProps) { return ( - {messages.map((item, idx) => ( + {messages.map((message, idx) => ( ))} diff --git a/website/src/components/Messages/MessageTableEntry.stories.tsx b/website/src/components/Messages/MessageTableEntry.stories.tsx index b6071dd7..3550d00f 100644 --- a/website/src/components/Messages/MessageTableEntry.stories.tsx +++ b/website/src/components/Messages/MessageTableEntry.stories.tsx @@ -1,4 +1,5 @@ import React from "react"; +import { Message } from "src/types/Conversation"; import { 
MessageTableEntry } from "./MessageTableEntry"; @@ -8,10 +9,8 @@ export default { component: MessageTableEntry, }; -const Template = ({ text, is_assistant, id, frontend_message_id, enabled, highlight }) => { - return ( - - ); +const Template = ({ enabled, highlight, ...message }) => { + return ; }; export const Default = Template.bind({}); @@ -22,6 +21,8 @@ Default.args = { frontend_message_id: "", enabled: true, highlight: false, + emojis: {}, + user_emojis: [], }; export const Asistant = Template.bind({}); @@ -32,6 +33,8 @@ Asistant.args = { frontend_message_id: "", enabled: true, highlight: false, + emojis: {}, + user_emojis: [], }; export const LongText = Template.bind({}); @@ -42,4 +45,18 @@ LongText.args = { frontend_message_id: "", enabled: true, highlight: false, + emojis: {}, + user_emojis: [], +}; + +export const WithEmoji = Template.bind({}); +WithEmoji.args = { + text: "As you\u2019ve mentioned, Star Wars has many sequels, prequels, and crossovers. The official list of movies in Star Wars is:", + is_assistant: true, + id: "", + frontend_message_id: "", + enabled: true, + highlight: false, + emojis: { "-1": 5, "+1": 1 }, + user_emojis: ["-1"], }; diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 77202c44..3cde48f6 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -1,23 +1,47 @@ -import { Avatar, Box, HStack, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; +import { + Avatar, + Box, + HStack, + Menu, + MenuButton, + MenuDivider, + MenuGroup, + MenuItem, + MenuList, + SimpleGrid, + useBreakpointValue, + useColorModeValue, + useDisclosure, +} from "@chakra-ui/react"; import { boolean } from "boolean"; +import { ClipboardList, Flag, MessageSquare, MoreHorizontal } from "lucide-react"; import { useRouter } from "next/router"; -import { useCallback, useMemo } from "react"; -import { 
FlaggableElement } from "src/components/FlaggableElement"; -import { Message } from "src/types/Conversation"; +import { useTranslation } from "next-i18next"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { LabelMessagePopup } from "src/components/Messages/LabelPopup"; +import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton"; +import { ReportPopup } from "src/components/Messages/ReportPopup"; +import { post } from "src/lib/api"; +import { Message, MessageEmojis } from "src/types/Conversation"; import { colors } from "styles/Theme/colors"; +import useSWRMutation from "swr/mutation"; interface MessageTableEntryProps { - item: Message; + message: Message; enabled?: boolean; highlight?: boolean; } -export function MessageTableEntry(props: MessageTableEntryProps) { +export function MessageTableEntry({ message, enabled, highlight }: MessageTableEntryProps) { const router = useRouter(); + const [emojis, setEmojis] = useState({ emojis: {}, user_emojis: [] }); + useEffect(() => { + setEmojis({ emojis: message.emojis, user_emojis: message.user_emojis }); + }, [message.emojis, message.user_emojis]); - const { item } = props; - - const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]); + const goToMessage = useCallback(() => router.push(`/messages/${message.id}`), [router, message.id]); + const { isOpen: reportPopupOpen, onOpen: showReportPopup, onClose: closeReportPopup } = useDisclosure(); + const { isOpen: labelPopupOpen, onOpen: showLabelPopup, onClose: closeLabelPopup } = useDisclosure(); const backgroundColor = useColorModeValue("gray.100", "gray.700"); const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B"); @@ -32,34 +56,124 @@ export function MessageTableEntry(props: MessageTableEntryProps) { borderColor={borderColor} size={inlineAvatar ? "xs" : "sm"} mr={inlineAvatar ? 2 : 0} - name={`${boolean(item.is_assistant) ? 
"Assistant" : "User"}`} - src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`} + name={`${boolean(message.is_assistant) ? "Assistant" : "User"}`} + src={`${boolean(message.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`} /> ), - [borderColor, inlineAvatar, item.is_assistant] + [borderColor, inlineAvatar, message.is_assistant] ); const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight); + const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, { + onSuccess: setEmojis, + }); + const react = (emoji: string, state: boolean) => { + sendEmojiChange({ op: state ? "add" : "remove", emoji }); + }; + return ( - - - {!inlineAvatar && avatar} - + {!inlineAvatar && avatar} + + {inlineAvatar && avatar} + {message.text} + e.stopPropagation()} > - {inlineAvatar && avatar} - {item.text} - - - + {Object.entries(emojis.emojis).map(([emoji, count]) => ( + react(emoji, !emojis.user_emojis.includes(emoji))} + /> + ))} + + + + + + ); } + +const EmojiMenuItem = ({ + emoji, + checked, + react, +}: { + emoji: string; + checked?: boolean; + react: (emoji: string, state: boolean) => void; +}) => { + const activeColor = useColorModeValue(colors.light.active, colors.dark.active); + + return ( + react(emoji, !checked)} justifyContent="center" color={checked ? 
activeColor : undefined}> + {getEmojiIcon(emoji, "NORMAL")} + + ); +}; + +const MessageActions = ({ + react, + userEmoji, + onLabel, + onReport, + messageId, +}: { + react: (emoji: string, state: boolean) => void; + userEmoji: string[]; + onLabel: () => void; + onReport: () => void; + messageId: string; +}) => { + const { t } = useTranslation("message"); + + return ( + + + + + + + + {["+1", "-1"].map((emoji) => ( + + ))} + + + + }> + {t("label_action")} + + }> + {t("report_action")} + + + }> + {t("open_new_tab_action")} + + + + ); +}; diff --git a/website/src/components/Messages/MessageWithChildren.tsx b/website/src/components/Messages/MessageWithChildren.tsx index ca29c410..f47cfac5 100644 --- a/website/src/components/Messages/MessageWithChildren.tsx +++ b/website/src/components/Messages/MessageWithChildren.tsx @@ -52,7 +52,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { {isFirst ? "Message" : depth === 1 ? "Children" : "Ancestor"} - + @@ -86,9 +86,9 @@ export function MessageWithChildren(props: MessageWithChildrenProps) { gap="4" shadow="base" > - {children.map((item, idx) => ( + {children.map((message, idx) => ( - + ))} diff --git a/website/src/components/Messages/ReportPopup.tsx b/website/src/components/Messages/ReportPopup.tsx new file mode 100644 index 00000000..67a2ea1b --- /dev/null +++ b/website/src/components/Messages/ReportPopup.tsx @@ -0,0 +1,56 @@ +import { + Button, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + Textarea, +} from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { useState } from "react"; +import { post } from "src/lib/api"; +import useSWRMutation from "swr/mutation"; + +interface ReportPopupProps { + messageId: string; + show: boolean; + onClose: () => void; +} + +export const ReportPopup = ({ messageId, show, onClose }: ReportPopupProps) => { + const { t } = useTranslation("message"); + const [text, setText] = 
useState(""); + const { trigger } = useSWRMutation("/api/report", post); + + const submit = () => { + trigger({ + message_id: messageId, + text, + }); + + setText(""); + onClose(); + }; + + return ( + + + + {t("report_title")} + + +