From 55a4cf1fd096f49583365bfac31e6c62636204a6 Mon Sep 17 00:00:00 2001 From: d_auras Date: Thu, 5 Jan 2023 11:12:11 +0100 Subject: [PATCH 01/69] finally hooked up redis insights to redis --- ansible/dev.yaml | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/ansible/dev.yaml b/ansible/dev.yaml index d022ba3c..04252c0d 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -10,6 +10,42 @@ state: present driver: bridge + - name: Set up Redis + community.docker.docker_container: + name: redis + #name: oasst-redis + image: redis + state: started + restart_policy: always + network_mode: oasst + ports: + - 6379:6379 + healthcheck: + test: ["CMD-SHELL", "redis-cli ping | grep PONG"] + interval: 2s + timeout: 2s + retries: 10 + command: redis-server /usr/local/etc/redis/redis.conf + volumes: + - "./redis.conf:/usr/local/etc/redis/redis.conf" + + - name: Set up Redis Insights + community.docker.docker_container: + name: redis-insights + #name: oasst-redis-insights + image: redislabs/redisinsight:latest + #command: pip install redis-cli python entry.pyc + state: started + restart_policy: always + network_mode: oasst + ports: + - 8001:8001 + env: + REDIS_URL: redis://redis:6379 + REDIS_HOST: redis + #environment: + #- REDIS_URL=redis://redis:6379 + - name: Create postgres containers community.docker.docker_container: name: "{{ item.name }}" From 8921b4f8dd9d2157f5e69b3bd1b231d08c0baec3 Mon Sep 17 00:00:00 2001 From: d_auras Date: Thu, 5 Jan 2023 11:31:13 +0100 Subject: [PATCH 02/69] added test files and README, ready for pull request --- ansible/README.md | 3 +++ ansible/dev.yaml | 12 ++---------- ansible/test.inventory.ini | 2 ++ 3 files changed, 7 insertions(+), 10 deletions(-) create mode 100644 ansible/README.md create mode 100644 ansible/test.inventory.ini diff --git a/ansible/README.md b/ansible/README.md new file mode 100644 index 00000000..b15bf75c --- /dev/null +++ b/ansible/README.md @@ -0,0 +1,3 @@ +To test the ansible playbook on localhost run ```ansible-playbook -i test.inventory.ini dev.yaml```. +Point Redis Insights to the Redis database by visiting localhost:8001 in a browser and select "I already have a database" followed by "Connect to a Redis Database". +For host, port and name fill in ```oasst-redis```, ```6379``` and ```redis```. diff --git a/ansible/dev.yaml b/ansible/dev.yaml index 04252c0d..c9195966 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -12,8 +12,7 @@ - name: Set up Redis community.docker.docker_container: - name: redis - #name: oasst-redis + name: oasst-redis image: redis state: started restart_policy: always @@ -31,20 +30,13 @@ - name: Set up Redis Insights community.docker.docker_container: - name: redis-insights - #name: oasst-redis-insights + name: oasst-redis-insights image: redislabs/redisinsight:latest - #command: pip install redis-cli python entry.pyc state: started restart_policy: always network_mode: oasst ports: - 8001:8001 - env: - REDIS_URL: redis://redis:6379 - REDIS_HOST: redis - #environment: - #- REDIS_URL=redis://redis:6379 - name: Create postgres containers community.docker.docker_container: diff --git a/ansible/test.inventory.ini b/ansible/test.inventory.ini new file mode 100644 index 00000000..bfe6d93f --- /dev/null +++ b/ansible/test.inventory.ini @@ -0,0 +1,2 @@ +[test] +dev ansible_connection=local From 43227b2cdcd241534355a5c8b075966bcc0d2b05 Mon Sep 17 00:00:00 2001 From: d_auras Date: Thu, 5 Jan 2023 11:42:02 +0100 Subject: [PATCH 03/69] fixed line breaks in README --- ansible/README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ansible/README.md b/ansible/README.md index b15bf75c..b8bd2a48 100644 --- a/ansible/README.md +++ b/ansible/README.md @@ -1,3 +1,4 @@ -To test the ansible playbook on localhost run ```ansible-playbook -i test.inventory.ini dev.yaml```. -Point Redis Insights to the Redis database by visiting localhost:8001 in a browser and select "I already have a database" followed by "Connect to a Redis Database". -For host, port and name fill in ```oasst-redis```, ```6379``` and ```redis```. +To test the ansible playbook on localhost run ```ansible-playbook -i test.inventory.ini dev.yaml```.\ +In case you're missing the ansible docker depencency install it with ```ansible-galaxy collection install community.docker```.\ +Point Redis Insights to the Redis database by visiting localhost:8001 in a browser and select "I already have a database" followed by "Connect to a Redis Database".\ +For host, port and name fill in ```oasst-redis```, ```6379``` and ```redis```.\ From aa09245f7346b02bb13e6c0767993ecdd1f18393 Mon Sep 17 00:00:00 2001 From: d_auras Date: Thu, 5 Jan 2023 11:44:39 +0100 Subject: [PATCH 04/69] fixed last trailing line break in README --- ansible/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ansible/README.md b/ansible/README.md index b8bd2a48..370f9e1e 100644 --- a/ansible/README.md +++ b/ansible/README.md @@ -1,4 +1,4 @@ To test the ansible playbook on localhost run ```ansible-playbook -i test.inventory.ini dev.yaml```.\ In case you're missing the ansible docker depencency install it with ```ansible-galaxy collection install community.docker```.\ Point Redis Insights to the Redis database by visiting localhost:8001 in a browser and select "I already have a database" followed by "Connect to a Redis Database".\ -For host, port and name fill in ```oasst-redis```, ```6379``` and ```redis```.\ +For host, port and name fill in ```oasst-redis```, ```6379``` and ```redis```. From d379193bed67db15fd9c56929320cc0fc1c7eccd Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Thu, 5 Jan 2023 16:15:20 +0100 Subject: [PATCH 05/69] add REDIS_HOST environment variable to backend to comply fully with `docker-compose.yaml` --- ansible/dev.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/ansible/dev.yaml b/ansible/dev.yaml index ca2a11d9..eea49a3e 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -79,6 +79,7 @@ network_mode: oasst env: POSTGRES_HOST: oasst-postgres + REDIS_HOST: oasst-redis DEBUG_ALLOW_ANY_API_KEY: "true" DEBUG_USE_SEED_DATA: "true" MAX_WORKERS: "1" From a40e6ec31d17f988d7154b663263de5d15d4d740 Mon Sep 17 00:00:00 2001 From: d_auras Date: Thu, 5 Jan 2023 16:25:58 +0100 Subject: [PATCH 06/69] ran pre-commit again --- ansible/README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ansible/README.md b/ansible/README.md index 370f9e1e..2ab1943e 100644 --- a/ansible/README.md +++ b/ansible/README.md @@ -1,4 +1,7 @@ -To test the ansible playbook on localhost run ```ansible-playbook -i test.inventory.ini dev.yaml```.\ -In case you're missing the ansible docker depencency install it with ```ansible-galaxy collection install community.docker```.\ -Point Redis Insights to the Redis database by visiting localhost:8001 in a browser and select "I already have a database" followed by "Connect to a Redis Database".\ -For host, port and name fill in ```oasst-redis```, ```6379``` and ```redis```. +To test the ansible playbook on localhost run +`ansible-playbook -i test.inventory.ini dev.yaml`.\ +In case you're missing the ansible docker depencency install it with `ansible-galaxy collection install community.docker`.\ +Point Redis Insights to the Redis database by visiting localhost:8001 in a +browser and select "I already have a database" followed by "Connect to a Redis +Database".\ +For host, port and name fill in `oasst-redis`, `6379` and `redis`. From 2870524aa780f55b2bd03d7d1aad1516864f187b Mon Sep 17 00:00:00 2001 From: d_auras Date: Fri, 6 Jan 2023 20:03:15 +0100 Subject: [PATCH 07/69] ansible copies redis.conf to managed node now --- ansible/dev.yaml | 5 +++++ ansible/redis.conf | 2 ++ ansible/remote.inventory.ini | 3 +++ 3 files changed, 10 insertions(+) create mode 100644 ansible/redis.conf create mode 100644 ansible/remote.inventory.ini diff --git a/ansible/dev.yaml b/ansible/dev.yaml index eea49a3e..90f7a85a 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -10,6 +10,11 @@ state: present driver: bridge + - name: Copy redis.conf to managed node + ansible.builtin.copy: + src: ./redis.conf + dest: ./redis.conf + - name: Set up Redis community.docker.docker_container: name: oasst-redis diff --git a/ansible/redis.conf b/ansible/redis.conf new file mode 100644 index 00000000..58da1e05 --- /dev/null +++ b/ansible/redis.conf @@ -0,0 +1,2 @@ +maxmemory 100mb +maxmemory-policy allkeys-lru diff --git a/ansible/remote.inventory.ini b/ansible/remote.inventory.ini new file mode 100644 index 00000000..a3afb2df --- /dev/null +++ b/ansible/remote.inventory.ini @@ -0,0 +1,3 @@ +[dev] +;list your remote hosts here +ubuntu-ssh From c1dab2d213e4c6684edb78e1b0b6de5f672ea64a Mon Sep 17 00:00:00 2001 From: d_auras Date: Fri, 6 Jan 2023 21:46:51 +0100 Subject: [PATCH 08/69] deleted my remote.inventory.ini --- ansible/remote.inventory.ini | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 ansible/remote.inventory.ini diff --git a/ansible/remote.inventory.ini b/ansible/remote.inventory.ini deleted file mode 100644 index a3afb2df..00000000 --- a/ansible/remote.inventory.ini +++ /dev/null @@ -1,3 +0,0 @@ -[dev] -;list your remote hosts here -ubuntu-ssh From 3625f39948937b9f5e46c666578048b6304eb6fe Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 7 Jan 2023 01:36:27 +0000 Subject: [PATCH 09/69] [feature] Add GPTJ synthetic dataset, fix reference removal regex for webgpt --- model/reward/instructor/rank_datasets.py | 44 ++++++++++++ model/reward/instructor/tests/test_dataset.py | 16 +++-- model/reward/instructor/trainer.py | 71 ++++--------------- model/reward/instructor/utils.py | 30 +++++++- 4 files changed, 96 insertions(+), 65 deletions(-) diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py index 5e7da948..a638c0d1 100644 --- a/model/reward/instructor/rank_datasets.py +++ b/model/reward/instructor/rank_datasets.py @@ -193,3 +193,47 @@ class HFSummary(Dataset): valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample) # optimize the format later return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx] + + +class HFDataset(Dataset): + """ + This is a base huggingface dataset which written to support the + simplest pos-neg pair format + + we should do something like this for supervised datasets + """ + + def __init__( + self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None + ) -> None: + super().__init__() + dataset = load_dataset(dataset_name, subset) + if split is not None: + dataset = dataset[split] + + self.questions = {} + self.index2question = {} + for row in dataset: + question = row[question_field].strip() + pos = row[pos_answer_field] + neg = row[neg_answer_field] + if question not in self.index2question: + self.index2question[len(self.index2question)] = question + + if question not in self.questions: + self.questions[question] = [] + self.questions[question].append((pos.strip(), neg.strip())) + + def __len__(self): + return len(self.index2question) + + def __getitem__(self, index): + question = self.index2question[index] + rows = self.questions[question] + # optimize the format later + return question, rows + + +class GPTJSynthetic(HFDataset): + def __init__(self) -> None: + super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train") diff --git a/model/reward/instructor/tests/test_dataset.py b/model/reward/instructor/tests/test_dataset.py index 746a3c1e..832aace3 100644 --- a/model/reward/instructor/tests/test_dataset.py +++ b/model/reward/instructor/tests/test_dataset.py @@ -1,5 +1,5 @@ from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality -from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT +from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -25,7 +25,7 @@ def test_webgpt(): print(batch["input_ids"].shape) -def test_hf_quality(): +def test_hf_summary_quality(): tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large") collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200) @@ -35,6 +35,12 @@ def test_hf_quality(): print(batch["input_ids"].shape) -if __name__ == "__main__": - test_hf_quality() - # test_webgpt() +def test_gptj_dataset(): + dataset = GPTJSynthetic() + tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large") + collate_fn = DataCollatorForPairRank(tokenizer, max_length=1024) + + print(len(dataset)) + dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32) + for batch in dataloader: + batch["input_ids"].shape diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py index f9266d70..ee330377 100644 --- a/model/reward/instructor/trainer.py +++ b/model/reward/instructor/trainer.py @@ -1,26 +1,15 @@ import os from argparse import ArgumentParser -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import evaluate import numpy as np import torch from models import RankGenModel -from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT +from rank_datasets import DataCollatorForPairRank, RankGenCollator from torch import nn -from torch.utils.data import ConcatDataset, Dataset -from transformers import ( - AutoModelForSequenceClassification, - DataCollator, - EvalPrediction, - PreTrainedModel, - PreTrainedTokenizerBase, - Trainer, - TrainerCallback, - TrainingArguments, -) -from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset +from transformers import AutoModelForSequenceClassification, PreTrainedModel, Trainer, TrainingArguments +from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer os.environ["WANDB_PROJECT"] = "reward-model" @@ -29,11 +18,6 @@ parser = ArgumentParser() parser.add_argument("config", type=str) -@dataclass -class CustomTrainingArguments(TrainingArguments): - loss_function: str = "rank" - - def compute_metrics(eval_pred): predictions, _ = eval_pred predictions = np.argmax(predictions, axis=1) @@ -57,31 +41,12 @@ class RankTrainer(Trainer): model: Union[PreTrainedModel, nn.Module] = None, model_name: str = None, args: Optional[TrainingArguments] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Dataset] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - model_init: Callable[[], PreTrainedModel] = None, - compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + loss_function: str = "rank", + **kwargs, ): - super().__init__( - model, - args, - data_collator, - train_dataset, - eval_dataset, - tokenizer, - model_init, - compute_metrics, - callbacks, - optimizers, - preprocess_logits_for_metrics, - ) - self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss() - self.loss_function = args.loss_function + super().__init__(model, args, **kwargs) + self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss() + self.loss_function = loss_function self.model_name = model_name def compute_loss(self, model, inputs, return_outputs=False): @@ -160,7 +125,7 @@ if __name__ == "__main__": params = sum([np.prod(p.size()) for p in model_parameters]) print("Number of trainable : {}M".format(int(params / 1e6))) - args = CustomTrainingArguments( + args = TrainingArguments( output_dir=f"{model_name}-finetuned", num_train_epochs=training_conf["num_train_epochs"], warmup_steps=500, @@ -181,22 +146,9 @@ if __name__ == "__main__": save_steps=1000, report_to="local", ) - train_datasets, evals = [], {} - if "webgpt" in training_conf["datasets"]: - web_dataset = WebGPT() - train, eval = train_val_dataset(web_dataset) - train_datasets.append(train) - evals["webgpt"] = eval - if "hfsummary" in training_conf["datasets"]: - sum_train = HFSummary(split="train") - train_datasets.append(sum_train) - sum_eval = HFSummary(split="valid1") - assert len(sum_eval) > 0 - evals["hfsummary"] = sum_eval - train = ConcatDataset(train_datasets) tokenizer = get_tokenizer(training_conf["tokenizer_name"]) - + train, evals = get_datasets(training_conf["datasets"]) if "rankgen" in model_name: collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"]) else: @@ -206,6 +158,7 @@ if __name__ == "__main__": model=model, model_name=model_name, args=args, + loss_function=training_conf["loss"], train_dataset=train, eval_dataset=eval, data_collator=collate_fn, diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py index fe52c2ef..a6f3da4e 100644 --- a/model/reward/instructor/utils.py +++ b/model/reward/instructor/utils.py @@ -1,11 +1,13 @@ import re +from typing import AnyStr, List import yaml from sklearn.model_selection import train_test_split from torch.utils.data import Subset from transformers import AutoTokenizer, T5Tokenizer -re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]") +# @agoryuno contributed this +re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") def webgpt_return_format(row): @@ -97,6 +99,32 @@ def argument_parsing(parser): return params +def get_datasets(dataset_list: List[AnyStr]): + from rank_datasets import GPTJSynthetic, HFSummary, WebGPT + from torch.utils.data import ConcatDataset + + train_datasets, evals = [], {} + for dataset_name in dataset_list: + if "webgpt" == dataset_name: + web_dataset = WebGPT() + train, eval = train_val_dataset(web_dataset, 0.2) + train_datasets.append(train) + evals["webgpt"] = eval + elif "hfsummary" == dataset_name: + sum_train = HFSummary(split="train") + train_datasets.append(sum_train) + sum_eval = HFSummary(split="valid1") + assert len(sum_eval) > 0 + evals["hfsummary"] = sum_eval + elif "gptsynthetic" == dataset_name: + dataset = GPTJSynthetic() + train, eval = train_val_dataset(dataset, 0.1) + train_datasets.append(train) + evals["gptsynthetic"] = eval + train = ConcatDataset(train_datasets) + return train, evals + + if __name__ == "__main__": from transformers import AutoModelForSequenceClassification From f57041cddab90c2fb2bb72d02eb79c0ee097396f Mon Sep 17 00:00:00 2001 From: klotske Date: Sat, 7 Jan 2023 15:08:38 +0300 Subject: [PATCH 10/69] Added basic Admin dashboard layout --- website/src/components/Layout.tsx | 20 +++++++++++++++++++- website/src/pages/admin/index.tsx | 4 ++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index 1faefcc0..c4450f58 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -1,7 +1,7 @@ // https://nextjs.org/docs/basic-features/layouts import type { NextPage } from "next"; -import { FiLayout, FiMessageSquare } from "react-icons/fi"; +import { FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi"; import { Header } from "src/components/Header"; import { Footer } from "./Footer"; @@ -51,4 +51,22 @@ export const getDashboardLayout = (page: React.ReactElement) => ( ); +export const getAdminLayout = (page: React.ReactElement) => ( +
+
+ + {page} + +
+); + export const noLayout = (page: React.ReactElement) => page; diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx index 60d61903..114eee3e 100644 --- a/website/src/pages/admin/index.tsx +++ b/website/src/pages/admin/index.tsx @@ -2,7 +2,7 @@ import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; import { useEffect } from "react"; -import { getTransparentHeaderLayout } from "src/components/Layout"; +import { getAdminLayout } from "src/components/Layout"; import UsersCell from "src/components/UsersCell"; /** @@ -44,6 +44,6 @@ const AdminIndex = () => { ); }; -AdminIndex.getLayout = getTransparentHeaderLayout; +AdminIndex.getLayout = getAdminLayout; export default AdminIndex; From 53814d77abc06a639bb7d8c7ea9ffecb6c58bc86 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sat, 7 Jan 2023 11:25:25 +0100 Subject: [PATCH 11/69] Label Initial Prompt --- website/.eslintrc.json | 3 +- .../src/components/Dashboard/TaskOption.tsx | 6 +- website/src/components/FlaggableElement.tsx | 3 +- website/src/components/Messages.tsx | 34 +++--- website/src/components/Tasks/TaskTypes.tsx | 15 ++- website/src/hooks/useLabelingTask.ts | 52 ++++++++ website/src/lib/oasst_api_client.ts | 2 +- website/src/pages/api/update_task.ts | 7 +- .../src/pages/label/label_initial_prompt.tsx | 113 ++++++++++++++++++ 9 files changed, 210 insertions(+), 25 deletions(-) create mode 100644 website/src/hooks/useLabelingTask.ts create mode 100644 website/src/pages/label/label_initial_prompt.tsx diff --git a/website/.eslintrc.json b/website/.eslintrc.json index 04b5d542..690c055c 100644 --- a/website/.eslintrc.json +++ b/website/.eslintrc.json @@ -8,7 +8,8 @@ "rules": { "unused-imports/no-unused-imports": "warn", "simple-import-sort/imports": "warn", - "simple-import-sort/exports": "warn" + "simple-import-sort/exports": "warn", + "eqeqeq": "warn" }, "plugins": ["simple-import-sort", "unused-imports"] } diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx index 5e6ceb2f..1c070e17 100644 --- a/website/src/components/Dashboard/TaskOption.tsx +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -3,7 +3,7 @@ import Link from "next/link"; import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes"; -const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate]; +const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label]; export const TaskOption = () => { const backgroundColor = useColorModeValue("white", "gray.700"); @@ -12,9 +12,9 @@ export const TaskOption = () => { {displayTaskCategories.map((category, categoryIndex) => (
- {TaskCategory[category]} + {category} - {TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => ( + {TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => ( { ); }; -function FlagCheckbox(props: { + +export function FlagCheckbox(props: { option: textFlagLabels; idx: number; checkboxValues: boolean[]; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index d3d7b3b8..7b69bc50 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,5 +1,6 @@ import { Grid } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; +import { useMemo } from "react"; import { FlaggableElement } from "./FlaggableElement"; @@ -8,29 +9,30 @@ export interface Message { is_assistant: boolean; } -const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => { - if (colorMode === "light") { - return isAssistant ? "bg-slate-800" : "bg-sky-900"; - } else { - return isAssistant ? "bg-black" : "bg-sky-900"; - } -}; - export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => { - const { colorMode } = useColorMode(); + const items = messages.map((messageProps: Message, i: number) => { + const { text } = messageProps; - const items = messages.map(({ text, is_assistant }: Message, i: number) => { return ( -
- {text} -
+
); }); // Maybe also show a legend of the colors? return {items}; }; + +export const MessageView = ({ is_assistant, text }: Message) => { + const { colorMode } = useColorMode(); + + const bgColor = useMemo(() => { + if (colorMode === "light") { + return is_assistant ? "bg-slate-800" : "bg-sky-900"; + } else { + return is_assistant ? "bg-black" : "bg-sky-900"; + } + }, [colorMode, is_assistant]); + + return
{text}
; +}; diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 413a1e16..7cec2177 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -1,9 +1,11 @@ export enum TaskCategory { - Create, - Evaluate, + Create = "Create", + Evaluate = "Evaluate", + Label = "Label", } export const TaskTypes = [ + // create { label: "Create Initial Prompts", desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.", @@ -31,6 +33,7 @@ export const TaskTypes = [ overview: "Given the following conversation, provide an adequate reply", instruction: "Provide the assistant`s reply", }, + // evaluate { label: "Rank User Replies", category: TaskCategory.Evaluate, @@ -52,4 +55,12 @@ export const TaskTypes = [ pathname: "/evaluate/rank_initial_prompts", type: "rank_initial_prompts", }, + // label + { + label: "Label Initial Prompt", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_initial_prompt", + type: "label_initial_prompt", + }, ]; diff --git a/website/src/hooks/useLabelingTask.ts b/website/src/hooks/useLabelingTask.ts new file mode 100644 index 00000000..872909b7 --- /dev/null +++ b/website/src/hooks/useLabelingTask.ts @@ -0,0 +1,52 @@ +import { useEffect, useState } from "react"; +import fetcher from "src/lib/fetcher"; +import poster from "src/lib/poster"; +import useSWRImmutable from "swr/immutable"; +import useSWRMutation from "swr/mutation"; + +// TODO: type & centralize types for all tasks +interface TaskResponse { + id: string; + userId: string; + task: TaskType; +} + +export interface LabelInitialPromptTask { + id: string; + message_id: string; + prompt: string; + type: string; + valid_labels: string[]; +} + +export type LabelInitialPromptTaskResponse = TaskResponse; + +export const useLabelingTask = ({ taskApiEndpoint }: { taskApiEndpoint: "label_initial_prompt" }) => { + type ConcreteTaskResponse = TaskResponse; + + const [tasks, setTasks] = useState>([]); + + const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, fetcher, { + onSuccess: (data: ConcreteTaskResponse) => { + setTasks([data]); + }, + }); + + useEffect(() => { + if (tasks.length === 0 && !isLoading && !error) { + mutate(); + } + }, [tasks, isLoading, mutate, error]); + + const { trigger } = useSWRMutation("/api/update_task", poster, { + onSuccess: async (reply) => { + const newTask: ConcreteTaskResponse = await reply.json(); + setTasks((oldTasks) => [...oldTasks, newTask]); + }, + }); + + const submit = (id: string, message_id: string, text: string, labels: Record) => + trigger({ id, update_type: "text_labels", content: { labels, text, message_id } }); + + return { tasks, isLoading, submit, error, reset: mutate }; +}; diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 4cf891e1..86854c21 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -42,7 +42,7 @@ export class OasstApiClient { } catch (e) { throw new OasstError(errorText, 0, resp.status); } - throw new OasstError(error.message, error.error_code, resp.status); + throw new OasstError(error.message ?? error, error.error_code, resp.status); } return await resp.json(); diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 4eea8c1e..c8760324 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -35,7 +35,12 @@ const handler = async (req, res) => { }, }); - const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token); + let newTask; + try { + newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token); + } catch (err) { + return res.status(500).json(err); + } // Stores the new task with our database. const newRegisteredTask = await prisma.registeredTask.create({ diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx new file mode 100644 index 00000000..66ab0580 --- /dev/null +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -0,0 +1,113 @@ +import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; +import { useColorMode } from "@chakra-ui/react"; +import { useEffect, useId, useState } from "react"; +import { LoadingScreen } from "src/components/Loading/LoadingScreen"; +import { MessageView } from "src/components/Messages"; +import { TaskControls } from "src/components/Survey/TaskControls"; +import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; +import { LabelInitialPromptTask, LabelInitialPromptTaskResponse, useLabelingTask } from "src/hooks/useLabelingTask"; +import { colors } from "styles/Theme/colors"; + +const LabelInitialPrompt = () => { + const [sliderValues, setSliderValues] = useState([]); + + const { tasks, isLoading, submit, reset } = useLabelingTask({ + taskApiEndpoint: "label_initial_prompt", + }); + + const submitResponse = ({ id, task }: LabelInitialPromptTaskResponse) => { + const labels = task.valid_labels.reduce((obj, label, i) => { + obj[label] = sliderValues[i].toString(); + return obj; + }, {} as Record); + + submit(id, task.message_id, task.prompt, labels); + }; + + const { colorMode } = useColorMode(); + const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; + + if (isLoading) { + return ; + } + + if (tasks.length === 0) { + return No tasks found...; + } + + const task = tasks[0].task; + + return ( +
+ + <> +
Label Initial Prompt
+

Provide labels for the following prompt

+ + + +
+ +
+ ); +}; + +export default LabelInitialPrompt; + +// TODO: consolidate with FlaggableElement + +interface CheckboxSliderGroupProps { + labelIDs: Array; + onChange: (sliderValues: number[]) => unknown; +} + +const CheckboxSliderGroup = ({ labelIDs, onChange }: CheckboxSliderGroupProps) => { + const [sliderValues, setSliderValues] = useState(Array.from({ length: labelIDs.length }).map(() => 0)); + + useEffect(() => { + onChange(sliderValues); + }, [sliderValues, onChange]); + + return ( + + {labelIDs.map((labelId, idx) => ( + { + const newState = sliderValues.slice(); + newState[idx] = sliderValue; + setSliderValues(newState); + }} + /> + ))} + + ); +}; + +function CheckboxSliderItem(props: { + labelId: string; + sliderValue: number; + sliderHandler: (newVal: number) => unknown; +}) { + const id = useId(); + const { colorMode } = useColorMode(); + + const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`; + + return ( + <> + + props.sliderHandler(val / 100)}> + + + + + + + ); +} From 2a5d37a07f59f229d594c57263a653db1efa6061 Mon Sep 17 00:00:00 2001 From: AbdBarho Date: Sat, 7 Jan 2023 17:25:05 +0100 Subject: [PATCH 12/69] Remove unused task elements --- .../components/TaskSelection/TaskOption.tsx | 39 ---------- .../components/TaskSelection/TaskOptions.tsx | 23 ------ .../TaskSelection/TaskSelection.tsx | 73 ------------------- website/src/components/TaskSelection/index.ts | 3 - .../src/pages/label/label_initial_prompt.tsx | 2 +- 5 files changed, 1 insertion(+), 139 deletions(-) delete mode 100644 website/src/components/TaskSelection/TaskOption.tsx delete mode 100644 website/src/components/TaskSelection/TaskOptions.tsx delete mode 100644 website/src/components/TaskSelection/TaskSelection.tsx delete mode 100644 website/src/components/TaskSelection/index.ts diff --git a/website/src/components/TaskSelection/TaskOption.tsx b/website/src/components/TaskSelection/TaskOption.tsx deleted file mode 100644 index 764efa68..00000000 --- a/website/src/components/TaskSelection/TaskOption.tsx +++ /dev/null @@ -1,39 +0,0 @@ -import { Card, CardBody, Flex, Heading } from "@chakra-ui/react"; -import Image from "next/image"; -import Link from "next/link"; - -export type OptionProps = { - img: string; - alt: string; - title: string; - link: string; -}; - -export const TaskOption = (props: OptionProps) => { - const { alt, img, title, link } = props; - return ( - - - - - {alt} - - {title} - - - - - - ); -}; diff --git a/website/src/components/TaskSelection/TaskOptions.tsx b/website/src/components/TaskSelection/TaskOptions.tsx deleted file mode 100644 index fe24b393..00000000 --- a/website/src/components/TaskSelection/TaskOptions.tsx +++ /dev/null @@ -1,23 +0,0 @@ -import { Divider, Flex, Heading } from "@chakra-ui/react"; -import React from "react"; - -export type TaskOptionsProps = { - title: string; - children: JSX.Element | JSX.Element[]; -}; - -export const TaskOptions = (props: TaskOptionsProps) => { - const { title, children } = props; - return ( - - - {title} - - - {children} - - ); -}; diff --git a/website/src/components/TaskSelection/TaskSelection.tsx b/website/src/components/TaskSelection/TaskSelection.tsx deleted file mode 100644 index 683c80e9..00000000 --- a/website/src/components/TaskSelection/TaskSelection.tsx +++ /dev/null @@ -1,73 +0,0 @@ -import { Flex } from "@chakra-ui/react"; -import { useColorMode } from "@chakra-ui/react"; -import React from "react"; - -import { TaskOption } from "./TaskOption"; -import { TaskOptions } from "./TaskOptions"; - -export const TaskSelection = () => { - const { colorMode } = useColorMode(); - const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white"; - - return ( - - - {/* */} - - - - - - {/* - Commented out while the backend does not support them. - */} - - - - - - ); -}; diff --git a/website/src/components/TaskSelection/index.ts b/website/src/components/TaskSelection/index.ts deleted file mode 100644 index d6d93973..00000000 --- a/website/src/components/TaskSelection/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export { TaskOption } from "./TaskOption"; -export { TaskOptions } from "./TaskOptions"; -export { TaskSelection } from "./TaskSelection"; diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index 66ab0580..0c3b47be 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -1,4 +1,4 @@ -import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; +import { Container, Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; import { useEffect, useId, useState } from "react"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; From 25ce928733ab47e676ce91ca29539a1c497a7d31 Mon Sep 17 00:00:00 2001 From: notmd Date: Sun, 8 Jan 2023 06:53:39 +0700 Subject: [PATCH 13/69] fix: improve `TaskControls` UI on mobile --- .../src/components/Survey/TaskControls.tsx | 21 +++++++++++-------- .../src/components/Survey/TrackedTextarea.tsx | 2 +- website/src/components/TaskInfo/TaskInfo.tsx | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/website/src/components/Survey/TaskControls.tsx b/website/src/components/Survey/TaskControls.tsx index 851e659c..ea71467c 100644 --- a/website/src/components/Survey/TaskControls.tsx +++ b/website/src/components/Survey/TaskControls.tsx @@ -1,5 +1,6 @@ import { useColorMode } from "@chakra-ui/react"; import { Flex } from "@chakra-ui/react"; +import clsx from "clsx"; import { SkipButton } from "src/components/Buttons/Skip"; import { SubmitButton } from "src/components/Buttons/Submit"; import { TaskInfo } from "src/components/TaskInfo/TaskInfo"; @@ -14,18 +15,20 @@ export interface TaskControlsProps { } export const TaskControls = (props: TaskControlsProps) => { - const extraClases = props.className || ""; const { colorMode } = useColorMode(); - - const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto"; - const taskControlClases = - colorMode === "light" - ? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}` - : `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`; - + const isLightMode = colorMode === "light"; const endTask = props.tasks[props.tasks.length - 1]; return ( -
+
Skip diff --git a/website/src/components/Survey/TrackedTextarea.tsx b/website/src/components/Survey/TrackedTextarea.tsx index f1691b72..d20107ac 100644 --- a/website/src/components/Survey/TrackedTextarea.tsx +++ b/website/src/components/Survey/TrackedTextarea.tsx @@ -28,7 +28,7 @@ export const TrackedTextarea = (props: TrackedTextboxProps) => { return ( -