diff --git a/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py b/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py new file mode 100644 index 00000000..19b497fa --- /dev/null +++ b/backend/alembic/versions/2023_01_19_2200-4f26fec4d204_add_ix_user_display_name_id.py @@ -0,0 +1,26 @@ +"""add ix_user_display_name_id + +Revision ID: 4f26fec4d204 +Revises: 7f0a28a156f4 +Create Date: 2023-01-19 22:00:00 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4f26fec4d204" +down_revision = "7f0a28a156f4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_user_display_name_id", table_name="user") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 0b2db515..f2fc3181 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/", response_model=list[protocol.FrontEndUser]) -def get_users( +@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True) +def get_users_ordered_by_username( api_client_id: Optional[UUID] = None, - max_count: Optional[int] = Query(100, gt=0, le=10000), - gt: Optional[str] = None, - lt: Optional[str] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, + search_text: Optional[str] = None, auth_method: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), 
api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): ur = UserRepository(db, api_client) - users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method) - return [u.to_protocol_frontend_user() for u in users] - - -@router.get("/by_display_name") -def query_frontend_users_by_display_name( - search_text: str, - exact: bool = False, - api_client_id: UUID = None, - max_count: int = Query(20, gt=0, le=1000), - auth_method: str = None, - api_client: ApiClient = Depends(deps.get_api_client), - db: Session = Depends(deps.get_db), -): - ur = UserRepository(db, api_client) - users = ur.query_users_by_display_name( - search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method + users = ur.query_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, ) return [u.to_protocol_frontend_user() for u in users] diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 36cd65c9..0b31495a 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/users/{user_id}", response_model=protocol.FrontEndUser) +@router.get("/by_username", response_model=list[protocol.FrontEndUser]) +def get_users_ordered_by_username( + api_client_id: Optional[UUID] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, + search_text: Optional[str] = None, + auth_method: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = 
UserRepository(db, api_client) + users = ur.query_users_ordered_by_username( + api_client_id=api_client_id, + gte_username=gte_username, + gt_id=gt_id, + lte_username=lte_username, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, + ) + return [u.to_protocol_frontend_user() for u in users] + + +@router.get("/by_display_name", response_model=list[protocol.FrontEndUser]) +def get_users_ordered_by_display_name( + api_client_id: Optional[UUID] = None, + gte_display_name: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_display_name: Optional[str] = None, + lt_id: Optional[UUID] = None, + auth_method: Optional[str] = None, + search_text: Optional[str] = None, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = UserRepository(db, api_client) + users = ur.query_users_ordered_by_display_name( + api_client_id=api_client_id, + gte_display_name=gte_display_name, + gt_id=gt_id, + lte_display_name=lte_display_name, + lt_id=lt_id, + auth_method=auth_method, + search_text=search_text, + limit=max_count, + ) + return [u.to_protocol_frontend_user() for u in users] + + +@router.get("/{user_id}", response_model=protocol.FrontEndUser) def get_user( user_id: UUID, api_client_id: UUID = None, @@ -31,7 +85,7 @@ def get_user( return user.to_protocol_frontend_user() -@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT) def update_user( user_id: UUID, enabled: Optional[bool] = None, @@ -46,7 +100,7 @@ def update_user( ur.update_user(user_id, enabled, notes) -@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT) def delete_user( user_id: UUID, db: Session = Depends(deps.get_db), diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 
0fb36c22..d882a15a 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel class User(SQLModel, table=True): __tablename__ = "user" - __table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),) + __table_args__ = ( + Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True), + Index("ix_user_display_name_id", "display_name", "id", unique=True), + ) id: Optional[UUID] = Field( sa_column=sa.Column( diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 7c46a026..c0c2a88d 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session +from sqlmodel import Session, and_, or_ from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -135,13 +135,16 @@ class UserRepository: self.db.add(user) return user - def query_users( + def query_users_ordered_by_username( self, api_client_id: Optional[UUID] = None, - limit: Optional[int] = 20, - gt: Optional[str] = None, - lt: Optional[str] = None, + gte_username: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_username: Optional[str] = None, + lt_id: Optional[UUID] = None, auth_method: Optional[str] = None, + search_text: Optional[str] = None, + limit: Optional[int] = 100, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -150,34 +153,52 @@ class UserRepository: if api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - users = self.db.query(User) + qry = 
self.db.query(User).order_by(User.username, User.id) - if api_client_id: - users = users.filter(User.api_client_id == api_client_id) + if gte_username is not None: + if gt_id: + qry = qry.filter( + or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id)) + ) + else: + qry = qry.filter(User.username >= gte_username) + elif gt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_username is not None: + if lt_id: + qry = qry.filter( + or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id)) + ) + else: + qry = qry.filter(User.username <= lte_username) + elif lt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) if auth_method: - users = users.filter(User.auth_method == auth_method) + qry = qry.filter(User.auth_method == auth_method) + if api_client_id: + qry = qry.filter(User.api_client_id == api_client_id) - users = users.order_by(User.id) - - if gt: - users = users.filter(User.id > gt) - - if lt: - users = users.filter(User.id < lt).order_by(None).order_by(User.id.desc()) + if search_text: + pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) + qry = qry.filter(User.username.like(pattern)) if limit is not None: - users = users.limit(limit) + qry = qry.limit(limit) - return users.all() + return qry.all() - def query_users_by_display_name( + def query_users_ordered_by_display_name( self, - search_text: str, - exact: Optional[bool] = False, - limit: Optional[int] = 20, + gte_display_name: Optional[str] = None, + gt_id: Optional[UUID] = None, + lte_display_name: Optional[str] = None, + lt_id: Optional[UUID] = None, api_client_id: Optional[UUID] = None, auth_method: Optional[str] = None, + search_text: Optional[str] = None, + limit: Optional[int] = 100, ) -> list[User]: if not self.api_client.trusted: if not api_client_id: @@ -186,11 +207,40 @@ class UserRepository: if 
api_client_id != self.api_client.id: raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) - qry = self.db.query(User).order_by(User.display_name) + qry = self.db.query(User).order_by(User.display_name, User.id) - if exact: - qry = qry.filter(User.display_name == search_text) - else: + if gte_display_name is not None: + if gt_id: + qry = qry.filter( + or_( + User.display_name > gte_display_name, + and_(User.display_name == gte_display_name, User.id > gt_id), + ) + ) + else: + qry = qry.filter(User.display_name >= gte_display_name) + elif gt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if lte_display_name is not None: + if lt_id: + qry = qry.filter( + or_( + User.display_name < lte_display_name, + and_(User.display_name == lte_display_name, User.id < lt_id), + ) + ) + else: + qry = qry.filter(User.display_name <= lte_display_name) + elif lt_id: + raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR) + + if auth_method: + qry = qry.filter(User.auth_method == auth_method) + if api_client_id: + qry = qry.filter(User.api_client_id == api_client_id) + + if search_text: pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%")) qry = qry.filter(User.display_name.like(pattern)) diff --git a/model/reward/instructor/configs/deberta-v3-base.yml b/model/reward/instructor/configs/deberta-v3-base.yml index 7023709c..134cfdaa 100644 --- a/model/reward/instructor/configs/deberta-v3-base.yml +++ b/model/reward/instructor/configs/deberta-v3-base.yml @@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base learning_rate: 1e-5 scheduler: cosine gradient_checkpointing: false -gradient_accumulation_steps: 32 +gradient_accumulation_steps: 16 per_device_train_batch_size: 2 warmup_steps: 600 eval_steps: 200 diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index 822121d8..d5b10e01 100644 
--- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -60,6 +60,26 @@ python trainer.py --configs defaults your-model-name --deepspeed ## Dataset choices +To specify which translation pair to use for +[WMT](https://huggingface.co/datasets/wmt19) and +[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation, simply +append the supported language pair to the dataset name as a suffix + +``` + datasets: + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja +``` + +Currently only these languages are supported via prompt translation: + +``` +ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh +``` + ## Results Experimental results in wandb diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 815c2e75..1d196fb2 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -29,6 +29,14 @@ defaults: - soda - joke - gsm8k + - dive_mt + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja + - instruct_tuning + - wmt2019_de-en - samsum - soda_dialogue cache_dir: .cache diff --git a/model/supervised_finetuning/custom_datasets/README.md b/model/supervised_finetuning/custom_datasets/README.md new file mode 100644 index 00000000..56a28574 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/README.md @@ -0,0 +1,27 @@ +# Dataset collections overview: + +Currently the datasets can be divided into 3 classes + +- language knowledge + + - summarization + + - translation + +- dialogue : don't let the user know you are a robot + +- STEM : knowledge about the world + + - coding + + - world knowledge <= ideally we want to handle this via prefix context + +Issues and TODO: + +- as the datasets are growing, how can we update this section less often + +- ideally we can update the config yaml and new datasets will be downloaded from + the hub + + - one possible idea is we upload the transformed format of 
these dataset to the + OA hub diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 3bec37e7..2e1e4b30 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,11 +1,26 @@ -from custom_datasets.prompt_dialogue import PromptGeneratedDataset +""" + High level functions for model training +""" +from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT from custom_datasets.summarization import SummarizationDataset +from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination +from custom_datasets.translation import WMT2019, DiveMT, TEDTalk from sklearn.model_selection import train_test_split from torch.utils.data import Subset QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"] -SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"] +SUMMARIZATION_DATASETS = [ + "xsum", + "cnn_dailymail", + "samsum", + "multi_news", + "scitldr", + "billsum", + "debate_sum", + "tldr_news", +] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"] def train_val_dataset(dataset, val_split=0.2): @@ -25,14 +40,34 @@ def get_one_dataset(conf, dataset_name): elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - val_name = "validation" if dataset_name not in ["billsum"] else "test" - eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + if dataset_name == "debate_sum": + train, eval = train_val_dataset(train, val_split=0.2) + else: + val_name = "validation" if dataset_name not in ["billsum"] else "test" + eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + elif "ted_trans" in 
dataset_name: + language_pair = dataset_name.split("_")[-1] + dataset = TEDTalk(pair=language_pair, split="train") + train, eval = train_val_dataset(dataset, val_split=0.2) + elif "wmt2019" in dataset_name: + language_pair = dataset_name.split("_")[-1] + train = WMT2019(pair=language_pair, split="train") + eval = WMT2019(pair=language_pair, split="validation") + elif dataset_name == "dive_mt": + dataset = DiveMT() + train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "prompt_dialogue": dataset = PromptGeneratedDataset(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "prosocial_dialogue": + train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train") + eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation") + elif dataset_name == "explain_prosocial": + train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train") + eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation") elif dataset_name == "soda": dataset = SODA(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.1) @@ -42,6 +77,9 @@ def get_one_dataset(conf, dataset_name): elif dataset_name == "joke": dataset = JokeExplaination(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "instruct_tuning": + dataset = InstructionTuning(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.2) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 2efe160f..719fa0d6 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -25,6 +25,7 @@ class DialogueDataCollator: for 
feature_one in features: assert len(feature_one) % 2 == 0, "Number of messages must be even" + # TODO: we should push this to dataset __getitem__ messages = [ (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") + x diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 372ea27f..4a1d83a3 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -1,3 +1,4 @@ +import json import os from urllib.request import urlopen @@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset): we are ignoring results with multiple lines for now """ + name = "prompt_dialogue" url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt" def __init__(self, cache_dir) -> None: @@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer + + +class InstructionTuning(Dataset): + """ + We have seen some promising capabilities from instruction tuning + with the following mix of datasets that are derived from datasets + available online. 
+ The files for this data are in json format as a list of tuples + where each tuple is (source,instruction_response_pair) + + - instruction_tuning_dataset_alpha_part1.json + - instruction_tuning_dataset_alpha_part2.json + + Not to be confused with the Unnatural Instructions dataset + """ + + name = "instruction_dataset" + url_part_2 = ( + "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json" + ) + url_part_1 = ( + "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json" + ) + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + + self.pairs = [] + for file_link in [self.url_part_1, self.url_part_2]: + basename = file_link.split("/")[-1] + instruction_tune_file = os.path.join(cache_dir, basename) + if not os.path.exists(instruction_tune_file): + with urlopen(file_link) as file: + content = file.read().decode() + with open(instruction_tune_file, "w", encoding="utf-8") as fout: + fout.write(content) + + with open(instruction_tune_file, "r", encoding="utf-8") as f: + datasets = json.load(f) + for row in datasets: + _, response_pair = row + question, answer = response_pair.split("\n\n", maxsplit=1) + answer = answer.replace("<|endoftext|>", "").strip() + self.pairs.append((question, answer)) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index d191c56c..7d9c7f48 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -1,11 +1,18 @@ +""" + Open / closed book QA datasets +""" import json import os +import re from urllib.request import urlopen import numpy as np from datasets import load_dataset from torch.utils.data import 
Dataset +# @agoryuno contributed this +re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") + QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} @@ -75,6 +82,9 @@ class QADataset(Dataset): class WebGPT(Dataset): + + name = "webgpt" + def __init__(self) -> None: super().__init__() @@ -89,7 +99,9 @@ class WebGPT(Dataset): self.index2question[len(self.index2question)] = question # only keep the best answer - questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + questions[question] = re_reference_remove.sub( + "", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + ) self.questions = questions @@ -103,6 +115,9 @@ class WebGPT(Dataset): class SODA(Dataset): + + name = "soda" + def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] @@ -207,8 +222,8 @@ class SODADialogue(Dataset): class JokeExplaination(Dataset): - """ """ + name = "joke" url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" def __init__(self, cache_dir) -> None: @@ -240,3 +255,6 @@ class JokeExplaination(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer + + +# https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 69e4b51d..2a097fe7 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -1,3 +1,6 @@ +""" + Summarize different spectrum of documents +""" import random from datasets import load_dataset @@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = { } summarization_config_mapping = { - "cnn_dailymail": ("3.0.0",), - "samsum": (), - "xsum": (), - "multi_news": (), - "scitldr": ("AIC",), - "billsum": (), - "reddit": (), + "cnn_dailymail": ( 
+ "cnn_dailymail", + "3.0.0", + ), + "samsum": ("samsum",), + "xsum": ("xsum",), + "multi_news": ("multi_news",), + "scitldr": ( + "scitldr", + "AIC", + ), + "billsum": ("billsum",), + "reddit": ("reddit",), + "tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news + "debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum } summarization_name_mapping = { @@ -29,6 +40,8 @@ summarization_name_mapping = { "scitldr": ("source", "target"), "billsum": ("text", "summary"), "reddit": ("content", "summary"), + "tldr_news": ("content", "headline"), + "debate_sum": ("Full-Document", "Extract"), } @@ -43,7 +56,7 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): def __init__(self, dataset, cache_dir, split): self.name = dataset - self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py new file mode 100644 index 00000000..815ac722 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -0,0 +1,68 @@ +""" + SFT dataset to reject toxic questions + +""" +import random + +from datasets import load_dataset +from torch.utils.data import Dataset + + +class ProsocialDialogueExplaination(Dataset): + name = "prosocial_explain" + TEMPLATE = [ + # 0 : reply or sentence of interest, 1 : reason of caution + ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"), + ("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"), + ("'{0}' Why is this sentence {1}", "This sentence may 
need {1} because {0}"), + ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"), + ] + + def __init__(self, split="train", cache_dir=".cache") -> None: + super().__init__() + dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] + self.pairs = [] + for row in dataset: + for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]): + (prompt_template, answer_template) = random.choice(self.TEMPLATE) + self.pairs.append( + ( + prompt_template.format(row["context"], safety_annotation), + answer_template.format(safe_answer, safety_annotation), + ) + ) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + return self.pairs[idx] + + +class ProsocialDialogue(Dataset): + name = "prosocial_dialogue" + """ + ProsocialDialog, we set up a human-AI collaborative data creation framework, + where GPT-3 generates the potentially unsafe utterances, and crowdworkers + provide prosocial responses to them. This approach allows us to circumvent + two substantial challenges: + (1) there are no available large-scale corpora of multiturn prosocial conversations + between humans + (2) asking humans to write unethical, toxic, or problematic utterances could result + in psychological harms (Roberts, 2017; Steiger et al., 2021). 
+ """ + PREFIX = "You are now a prosocial chatbot, be caution and casual when reply" + + def __init__(self, split="train", cache_dir=".cache") -> None: + super().__init__() + dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] + self.pairs = [] + for row in dataset: + for answer in row["rots"]: + self.pairs.append((self.PREFIX + row["context"], answer)) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + return self.pairs[idx] diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py new file mode 100644 index 00000000..694d31ce --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -0,0 +1,142 @@ +""" + List of translation dataset + + GroNLP/divemt + + fill in the blanks : https://huggingface.co/datasets/m_lama + +""" +import random + +from datasets import load_dataset +from torch.utils.data import Dataset + +# postfix prompt +TRANSLATION_PROMPT = { + "zh": [ # simplified or any chinese which was not mentioned + "Translate to chinese simplified: {}", + "{}, translate to chinese", + "{} give me the chinese translation", + "翻译成中文: {}", + "{} 这句中文翻译怎麽写?", + "我需要这句话的中文翻译: {}", + ], + "zh-tw": [ # WMT code + "{}. Translate to chinese traditional", + "{}, translate to chinese", + "{}. get chinese translation", + "中文翻譯: {}", + "幫我翻譯成中文: '{}'", + "{} 這句中文翻譯怎麼寫?", + ], + "ja": [ + "{}: help me translate to japanese", + "Need japanese translation: {}", + "{}: にほんごやくをよこす", + "{}: にほんごやくをおくれ", + "{}: にほんごやくを じょす", + "give me the japanese translation, {}", + ], + "de": [ + "{}: translate to german", + "give me the german translation {}", + "I want german translation {}", + "{}, ins Deutsche übersetzen", + "{}, Übersetzen ins Deutsche", + ], + "fr": [ + "{}. translate to french", + "{} write in french", + "{} french translation", + "{} ,donnez moi la traduction française", + ], + "ko": [ + "{}. 
translate to Korean", + "how do we write in korean: {}", + "give me the korean translation: {}", + "{}, 한국어 번역을 해주세요", + ], + "ms": [ + "{} translate to malay", + "{} how do we write in Malay", + "{} give me the malay translation", + "{} , berikan saya terjemahan dalam bahasa melayu", + "{}, Jemahan di bahasa melayu", "{}, jemahkan ayat ini kepada bahasa melayu", + ], + "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"], + "ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"], + "tr": ["{}. türkçeye çevi̇ri̇n", "{} write in turkish", "turkish translation: '{}'", "türkçeye çevi̇rmek: {}"], + "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"], + "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"], + "vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"], + "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"], +} + + +class TranslationPair(Dataset): + def __init__(self) -> None: + super().__init__() + self.pairs = [] + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + return self.pairs[index] + + +class WMT2019(TranslationPair): + def __init__(self, pair="zh-en", split="train") -> None: + super().__init__() + dataset = load_dataset("wmt19", pair)[split] + self.pairs = [] + src, tgt = pair.split("-") + for row in dataset: + row = row["translation"] + if random.random() > 0.5: + source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src]) + self.pairs.append((source, row[tgt])) + else: # translating in reverse direction + source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) + self.pairs.append((source, row[src])) + + +class DiveMT(TranslationPair): + + REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"} + + def __init__(self, split="train") -> None: + 
super().__init__() + dataset = load_dataset("GroNLP/divemt", "main")[split] + tgt, src = "tgt_text", "src_text" + for row in dataset: + # ISO 639-2 + lang_code_2 = row["subject_id"].split("_")[0] + lang_code = self.REMAP[lang_code_2] + if lang_code not in TRANSLATION_PROMPT: + continue + + if random.random() > 0.5: + source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[src]) + self.pairs.append((source, row[tgt])) + else: # translating in reverse direction + lang_code = "en" + source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[tgt]) + self.pairs.append((source, row[src])) + + +class TEDTalk(TranslationPair): + # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean + + def __init__(self, pair="de-ja", split="train", year="2016") -> None: + super().__init__() + dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split] + src, tgt = pair.split("-") + for row in dataset: + row = row["translation"] + if random.random() > 0.5: + source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src]) + self.pairs.append((source, row[tgt])) + else: # translating in reverse direction + source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) + self.pairs.append((source, row[src])) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index c9363303..3b59f289 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS - others = ["prompt_dialogue", "webgpt", "soda", "joke"] + others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"] + translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") - 
for dataset_name in others + qa_base + summarize_base: + for dataset_name in translation + others + summarize_base + qa_base: print(dataset_name) train, eval = get_one_dataset(config, dataset_name) # sanity check @@ -48,7 +49,3 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: assert batch["targets"].shape[1] <= 512 - - -if __name__ == "__main__": - test_collate_fn() diff --git a/website/cypress/contract/oasst_api_contract_tests.cy.ts b/website/cypress/contract/oasst_api_contract_tests.cy.ts index cf1c7506..d2ffeba3 100644 --- a/website/cypress/contract/oasst_api_contract_tests.cy.ts +++ b/website/cypress/contract/oasst_api_contract_tests.cy.ts @@ -1,34 +1,27 @@ import { OasstApiClient, OasstError } from "src/lib/oasst_api_client"; +import type { BackendUserCore } from "src/types/Users"; describe("Contract test for Oasst API", function () { // Assumes this is running the mock server. const oasstApiClient = new OasstApiClient("http://localhost:8080", "test"); + const testUser = { + id: "abcd", + display_name: "test", + auth_method: "local", + } as BackendUserCore; + it("can fetch a task", async () => { - expect( - await oasstApiClient.fetchTask("random", { - sub: "test", - name: "test", - email: "test", - }) - ).to.be.not.null; + expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null; }); it("can ack a task", async () => { - const task = await oasstApiClient.fetchTask("random", { - sub: "test", - name: "test", - email: "test", - }); + const task = await oasstApiClient.fetchTask("random", testUser); expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null; }); it("can record a taskInteraction", async () => { - const task = await oasstApiClient.fetchTask("random", { - sub: "test", - name: "test", - email: "test", - }); + const task = await oasstApiClient.fetchTask("random", testUser); expect( await oasstApiClient.interactTask( "text_reply_to_message", @@ 
-36,11 +29,7 @@ describe("Contract test for Oasst API", function () { "321", "1", { text: "Test" }, - { - sub: "test", - name: "test", - email: "test", - } + testUser ) ).to.be.not.null; }); diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index 0b2df79c..99edf6c3 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -1,4 +1,17 @@ { + "about": "About", + "account_settings": "Account", + "connect": "Connect", + "conversational": "Conversational AI for everyone.", + "dashboard": "Dashboard", "discord": "Discord", - "github": "GitHub" + "docs": "Docs", + "github": "GitHub", + "legal": "Legal", + "privacy_policy": "Privacy Policy", + "report_a_bug": "Report a Bug", + "sign_in": "Sign In", + "sign_out": "Sign Out", + "terms_of_service": "Terms of Service", + "title": "Open Assistant" } diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx index 1b421126..e2bafac3 100644 --- a/website/src/components/Dashboard/TaskOption.tsx +++ b/website/src/components/Dashboard/TaskOption.tsx @@ -1,19 +1,19 @@ import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Link from "next/link"; -import { TaskTypes } from "../Tasks/TaskTypes"; +import { TaskCategory, TaskCategoryLabels, TaskTypes } from "../Tasks/TaskTypes"; -export const TaskOption = ({ displayTaskCategories }) => { +export const TaskOption = ({ displayTaskCategories }: { displayTaskCategories: TaskCategory[] }) => { const backgroundColor = useColorModeValue("white", "gray.700"); return ( - {displayTaskCategories.map((category, categoryIndex) => ( -
- {category} + {displayTaskCategories.map((category) => ( +
+ {TaskCategoryLabels[category]} - {TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => ( - + {TaskTypes.filter((task) => task.category === category).map((item) => ( + { }, [data, isLoading]); const { trigger } = useSWRMutation("/api/set_label", post, { - onSuccess: () => { - setIsEditing.off(); - }, + onSuccess: setIsEditing.off, }); const submitResponse = () => { @@ -149,7 +147,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => { isLazy lazyBehavior="keepMounted" > - + {props.children} diff --git a/website/src/components/Footer.tsx b/website/src/components/Footer.tsx index b239708a..68cd7c01 100644 --- a/website/src/components/Footer.tsx +++ b/website/src/components/Footer.tsx @@ -1,9 +1,11 @@ import { Box, Divider, Flex, Text, useColorMode } from "@chakra-ui/react"; import Image from "next/image"; import Link from "next/link"; +import { useTranslation } from "next-i18next"; import { useMemo } from "react"; export function Footer() { + const { t } = useTranslation(); const { colorMode } = useColorMode(); const backgroundColor = colorMode === "light" ? "white" : "gray.800"; const textColor = colorMode === "light" ? "black" : "gray.300"; @@ -33,10 +35,10 @@ export function Footer() { - Open Assistant + {t("title")} - Conversational AI for everyone. 
+ {t("conversational")} @@ -45,23 +47,23 @@ export function Footer() { - Legal + {t("legal")} - - + + - Connect + {t("connect")} - - + + - About + {t("about")} - + diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx index de5abec2..a1b36123 100644 --- a/website/src/components/Header/Header.tsx +++ b/website/src/components/Header/Header.tsx @@ -1,7 +1,8 @@ -import { Box, Button, Text, Flex } from "@chakra-ui/react"; +import { Box, Button, Flex, Text } from "@chakra-ui/react"; import Image from "next/image"; import Link from "next/link"; import { useSession } from "next-auth/react"; +import { useTranslation } from "next-i18next"; import { Flags } from "react-feature-flags"; import { FaUser } from "react-icons/fa"; @@ -23,7 +24,8 @@ function AccountButton() { ); } -export function Header(props) { +export function Header() { + const { t } = useTranslation(); const { data: session } = useSession(); const homeURL = session ? "/dashboard" : "/"; @@ -34,7 +36,7 @@ export function Header(props) { logo - Open Assistant + {t("title")} diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 99ec01f1..6fdde69e 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -13,6 +13,7 @@ import { } from "@chakra-ui/react"; import NextLink from "next/link"; import { signOut, useSession } from "next-auth/react"; +import { useTranslation } from "next-i18next"; import React, { ElementType, useCallback } from "react"; import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi"; @@ -25,6 +26,7 @@ interface MenuOption { } export function UserMenu() { + const { t } = useTranslation(); const borderColor = useColorModeValue("gray.300", "gray.600"); const handleSignOut = useCallback(() => { signOut({ callbackUrl: "/" }); @@ -36,23 +38,23 @@ export function UserMenu() { } const options: MenuOption[] = [ { - name: 
"Dashboard", + name: t("dashboard"), href: "/dashboard", - desc: "Dashboard", + desc: t("dashboard"), icon: FiLayout, isExternal: false, }, { - name: "Account Settings", + name: t("account_settings"), href: "/account", - desc: "Account Settings", + desc: t("account_settings"), icon: FiSettings, isExternal: false, }, { - name: "Report a Bug", + name: t("report_a_bug"), href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose", - desc: "Report a Bug", + desc: t("report_a_bug"), icon: FiAlertTriangle, isExternal: true, }, @@ -60,9 +62,9 @@ export function UserMenu() { if (session.user.role === "admin") { options.unshift({ - name: "Admin Dashboard", + name: t("admin_dashboard"), href: "/admin", - desc: "Admin Dashboard", + desc: t("admin_dashboard"), icon: FiShield, isExternal: false, }); @@ -105,7 +107,7 @@ export function UserMenu() { diff --git a/website/src/components/Hero.tsx b/website/src/components/Hero.tsx index 4605e9e2..d401e47e 100644 --- a/website/src/components/Hero.tsx +++ b/website/src/components/Hero.tsx @@ -2,8 +2,8 @@ import { Box, Text, useColorMode } from "@chakra-ui/react"; import Image from "next/image"; import { useTranslation } from "next-i18next"; -import { Container } from "./Container"; import { AnimatedCircles } from "./AnimatedCircles"; +import { Container } from "./Container"; export function Hero() { const { t } = useTranslation("index"); diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index 70a2ce2c..484d16ec 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -23,7 +23,7 @@ export const getDefaultLayout = (page: React.ReactElement) => ( export const getTransparentHeaderLayout = (page: React.ReactElement) => (
-
+
{page}
@@ -31,7 +31,7 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => ( export const getDashboardLayout = (page: React.ReactElement) => ( -
+
( export const getAdminLayout = (page: React.ReactElement) => (
-
+
+ {messages.map((item) => ( ))} diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 8e9d03b6..d18bd910 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -1,6 +1,8 @@ -import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react"; +import { Avatar, Box, HStack, LinkBox, useBreakpoint, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; import Link from "next/link"; +import { useRouter } from "next/router"; +import { useCallback, useMemo } from "react"; import { FlaggableElement } from "src/components/FlaggableElement"; import { Message } from "src/types/Conversation"; @@ -10,47 +12,48 @@ interface MessageTableEntryProps { } export function MessageTableEntry(props: MessageTableEntryProps) { + const router = useRouter(); + const { item } = props; + + const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]); + const backgroundColor = useColorModeValue("gray.100", "gray.700"); const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B"); - const avatarColor = useColorModeValue("white", "black"); const borderColor = useColorModeValue("blackAlpha.200", "whiteAlpha.200"); + const inlineAvatar = useBreakpointValue({ base: true, sm: false }); + + const avatar = useMemo( + () => ( + + ), + [borderColor, inlineAvatar, item.is_assistant] + ); + return ( - - + {!inlineAvatar && avatar} + + {inlineAvatar && avatar} + {item.text} - {props.enabled ? 
( - - - - {item.text} - - - - ) : ( - - {item.text} - - )} ); diff --git a/website/src/components/Survey/SurveyCard.tsx b/website/src/components/Survey/SurveyCard.tsx index 5a78ce2b..6101f787 100644 --- a/website/src/components/Survey/SurveyCard.tsx +++ b/website/src/components/Survey/SurveyCard.tsx @@ -1,22 +1,19 @@ import { Box, BoxProps, useColorModeValue } from "@chakra-ui/react"; +import clsx from "clsx"; +import { PropsWithChildren } from "react"; -interface SurveyCardProps { - className?: string; - children: React.ReactNode; -} - -export const SurveyCard = (props: SurveyCardProps) => { +export const SurveyCard = (props: PropsWithChildren<{ className?: string }>) => { const backgroundColor = useColorModeValue("white", "gray.700"); const BoxClasses: BoxProps = { gap: "2", borderRadius: "xl", shadow: "base", - className: "p-4 sm:p-6", + className: clsx("p-4 sm:p-6", props.className), }; return ( - + {props.children} ); diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 868a9fb8..4c6da92c 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -1,5 +1,5 @@ export enum TaskCategory { - Tasks = "Tasks", + Random = "Random", Create = "Create", Evaluate = "Evaluate", Label = "Label", @@ -20,12 +20,19 @@ export interface TaskInfo { unchanged_message?: string; } +export const TaskCategoryLabels: { [key in TaskCategory]: string } = { + [TaskCategory.Random]: "I'm feeling lucky", + [TaskCategory.Create]: "Create", + [TaskCategory.Evaluate]: "Evaluate", + [TaskCategory.Label]: "Label", +}; + export const TaskTypes: TaskInfo[] = [ // general/random { label: "Start a Task", desc: "Help us improve Open Assistant by starting a random task.", - category: TaskCategory.Tasks, + category: TaskCategory.Random, pathname: "/tasks/random", help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting", type: "random", @@ -121,7 +128,7 @@ export const TaskTypes: 
TaskInfo[] = [ category: TaskCategory.Label, pathname: "/label/label_prompter_reply", help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/label_prompter_reply", - overview: "Given the following discussion, provide labels for the final prompt", + overview: "Given the following discussion, provide labels for the final prompt.", type: "label_prompter_reply", mode: "full", update_type: "text_labels", diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx index e8f490ae..d258ff47 100644 --- a/website/src/hooks/tasks/useGenericTaskAPI.tsx +++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx @@ -1,23 +1,22 @@ import { useState } from "react"; import { get, post } from "src/lib/api"; -import { BaseTask, TaskResponse } from "src/types/Task"; +import { BaseTask, TaskResponse, TaskType as TaskTypeEnum } from "src/types/Task"; import useSWRImmutable from "swr/immutable"; import useSWRMutation from "swr/mutation"; -export const useGenericTaskAPI = (taskApiEndpoint: string) => { +export const useGenericTaskAPI = (taskType: TaskTypeEnum) => { type ConcreteTaskResponse = TaskResponse; const [tasks, setTasks] = useState([]); - const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskApiEndpoint, get, { + const { isLoading, mutate, error } = useSWRImmutable("/api/new_task/" + taskType, get, { onSuccess: (data) => setTasks([data]), revalidateOnMount: true, dedupingInterval: 500, }); const { trigger } = useSWRMutation("/api/update_task", post, { - onSuccess: async (response) => { - const newTask: ConcreteTaskResponse = response; + onSuccess: async (newTask: ConcreteTaskResponse) => { setTasks((oldTasks) => [...oldTasks, newTask]); mutate(); }, diff --git a/website/src/lib/default_static_props.ts b/website/src/lib/default_static_props.ts new file mode 100644 index 00000000..365099cf --- /dev/null +++ b/website/src/lib/default_static_props.ts @@ -0,0 +1,7 @@ +import { serverSideTranslations } from 
"next-i18next/serverSideTranslations"; + +export const getDefaultStaticProps = async ({ locale }) => ({ + props: { + ...(await serverSideTranslations(locale)), + }, +}); diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index e1e68103..7db6e3c2 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -1,7 +1,7 @@ -import { JWT } from "next-auth/jwt"; import type { Message } from "src/types/Conversation"; import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; -import type { BackendUser } from "src/types/Users"; +import type { AvailableTasks } from "src/types/Task"; +import type { BackendUser, BackendUserCore } from "src/types/Users"; export class OasstError { message: string; @@ -108,14 +108,10 @@ export class OasstApiClient { // TODO return a strongly typed Task? // This method is used to store a task in RegisteredTask.task. // This is a raw Json type, so we can't use it to strongly type the task. 
- async fetchTask(taskType: string, userToken: JWT): Promise { + async fetchTask(taskType: string, user: BackendUserCore): Promise { return this.post("/api/v1/tasks/", { type: taskType, - user: { - id: userToken.sub, - display_name: userToken.name, - auth_method: "local", - }, + user, }); } @@ -140,15 +136,11 @@ export class OasstApiClient { messageId: string, userMessageId: string, content: object, - userToken: JWT + user: BackendUserCore ): Promise { return this.post("/api/v1/tasks/interaction", { type: updateType, - user: { - id: userToken.sub, - display_name: userToken.name, - auth_method: "local", - }, + user, task_id: taskId, message_id: messageId, user_message_id: userMessageId, @@ -224,6 +216,13 @@ export class OasstApiClient { async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise { return this.get(`/api/v1/leaderboards/${time_frame}`); } + + /** + * Returns the counts of all tasks (some might be zero) + */ + async fetch_available_tasks(user: BackendUserCore): Promise { + return this.post(`/api/v1/tasks/availability`, user); + } } const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY); diff --git a/website/src/lib/users.ts b/website/src/lib/users.ts new file mode 100644 index 00000000..2aa8c708 --- /dev/null +++ b/website/src/lib/users.ts @@ -0,0 +1,38 @@ +import prisma from "src/lib/prismadb"; +import type { BackendUserCore } from "src/types/Users"; + +/** + * Returns a `BackendUserCore` that can be used for interacting with the Backend service. + * + * @param {string} id The user's web auth id. + * + * @return {BackendUserCore} The most specific auth type and id for the user. + */ +const getBackendUserCore = async (id: string) => { + const user = await prisma.user.findUnique({ + where: { id }, + select: { + id: true, + name: true, + accounts: true, + }, + }); + + // If there are no linked accounts, just use what we have locally. 
+ if (user.accounts.length === 0) { + return { + id: user.id, + display_name: user.name, + auth_method: "local", + } as BackendUserCore; + } + + // Otherwise, use the first linked account that the user created. + return { + id: user.accounts[0].providerAccountId, + display_name: user.name, + auth_method: user.accounts[0].provider, + } as BackendUserCore; +}; + +export { getBackendUserCore }; diff --git a/website/src/pages/404.tsx b/website/src/pages/404.tsx index afe1d080..d4c58b54 100644 --- a/website/src/pages/404.tsx +++ b/website/src/pages/404.tsx @@ -3,6 +3,7 @@ import Head from "next/head"; import { FiAlertTriangle } from "react-icons/fi"; import { EmptyState } from "src/components/EmptyState"; import { getTransparentHeaderLayout } from "src/components/Layout"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; function Error() { return ( diff --git a/website/src/pages/500.tsx b/website/src/pages/500.tsx index 49eb2950..378bdfff 100644 --- a/website/src/pages/500.tsx +++ b/website/src/pages/500.tsx @@ -3,6 +3,7 @@ import Head from "next/head"; import { FiAlertTriangle } from "react-icons/fi"; import { EmptyState } from "src/components/EmptyState"; import { getTransparentHeaderLayout } from "src/components/Layout"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; function ServerError() { return ( diff --git a/website/src/pages/about.tsx b/website/src/pages/about.tsx index 490a6095..01182be0 100644 --- a/website/src/pages/about.tsx +++ b/website/src/pages/about.tsx @@ -4,6 +4,7 @@ import { Container } from "src/components/Container"; import Roadmap from "src/components/Roadmap"; import Services from "src/components/Services"; import Vision from "src/components/Vision"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const AboutPage = () => { return ( diff --git a/website/src/pages/account/edit.tsx b/website/src/pages/account/edit.tsx index 
fe8e8981..52af7e5e 100644 --- a/website/src/pages/account/edit.tsx +++ b/website/src/pages/account/edit.tsx @@ -4,6 +4,7 @@ import Router from "next/router"; import { useSession } from "next-auth/react"; import React from "react"; import { Control, useForm, useWatch } from "react-hook-form"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; export default function Account() { const { data: session } = useSession(); diff --git a/website/src/pages/account/index.tsx b/website/src/pages/account/index.tsx index d26fc842..88964b3b 100644 --- a/website/src/pages/account/index.tsx +++ b/website/src/pages/account/index.tsx @@ -1,8 +1,11 @@ -import { Button } from "@chakra-ui/react"; +import { Button, Divider, Flex, Grid, Icon, Text } from "@chakra-ui/react"; import Head from "next/head"; import Link from "next/link"; import { useSession } from "next-auth/react"; import React from "react"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; +import { MdOutlineEdit } from "react-icons/md"; +import { SurveyCard } from "src/components/Survey/SurveyCard"; export default function Account() { const { data: session } = useSession(); @@ -19,15 +22,28 @@ export default function Account() { content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world." /> -
-
-

{session.user.name || "No username"}

- -

{session.user.email}

-
-
+
+ + + + Your Account + + + + Username + + {session.user.name ?? "(No username)"} + + + + + Email + {session.user.email ?? "(No Email)"} + +

+
+
+
); } diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx index 397230bd..ede9f59c 100644 --- a/website/src/pages/admin/index.tsx +++ b/website/src/pages/admin/index.tsx @@ -4,6 +4,7 @@ import { useSession } from "next-auth/react"; import { useEffect } from "react"; import { getAdminLayout } from "src/components/Layout"; import { UserTable } from "src/components/UserTable"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; /** * Provides the admin index page that will display a list of users and give diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index 88bfced4..b53bb7c0 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -3,6 +3,7 @@ import { InferGetServerSidePropsType } from "next"; import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { useEffect } from "react"; import { useForm } from "react-hook-form"; import { getAdminLayout } from "src/components/Layout"; @@ -111,7 +112,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType { + const user = await getBackendUserCore(token.sub); + const availableTasks = await oasstApiClient.fetch_available_tasks(user); + res.status(200).json(availableTasks); +}); + +export default handler; diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts index e77c5eb2..c8255b18 100644 --- a/website/src/pages/api/new_task/[task_type].ts +++ b/website/src/pages/api/new_task/[task_type].ts @@ -1,6 +1,7 @@ import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; +import { getBackendUserCore } from "src/lib/users"; /** * Returns a new task created from the Task 
Backend. We do a few things here: @@ -14,9 +15,10 @@ const handler = withoutRole("banned", async (req, res, token) => { // Fetch the new task. const { task_type } = req.query; + const user = await getBackendUserCore(token.sub); let task; try { - task = await oasstApiClient.fetchTask(task_type as string, token); + task = await oasstApiClient.fetchTask(task_type as string, user); } catch (err) { console.error(err); res.status(500).json(err); diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts index 02982daa..c547503a 100644 --- a/website/src/pages/api/update_task.ts +++ b/website/src/pages/api/update_task.ts @@ -2,6 +2,7 @@ import { Prisma } from "@prisma/client"; import { withoutRole } from "src/lib/auth"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; +import { getBackendUserCore } from "src/lib/users"; /** * Stores the task interaction with the Task Backend and then returns the next task generated. @@ -39,9 +40,10 @@ const handler = withoutRole("banned", async (req, res, token) => { }, }); + const user = await getBackendUserCore(token.sub); let newTask; try { - newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, token); + newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, user); } catch (err) { console.error(JSON.stringify(err)); return res.status(500).json(err); diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx index ccadf55f..e3757190 100644 --- a/website/src/pages/auth/signin.tsx +++ b/website/src/pages/auth/signin.tsx @@ -5,6 +5,7 @@ import Head from "next/head"; import Link from "next/link"; import { useRouter } from "next/router"; import { ClientSafeProvider, getProviders, signIn } from "next-auth/react"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import React, { useEffect, useRef, useState } from 
"react"; import { useForm } from "react-hook-form"; import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; @@ -47,7 +48,6 @@ interface SigninProps { function Signin({ providers }: SigninProps) { const router = useRouter(); const { discord, email, github, credentials } = providers; - const emailEl = useRef(null); const [error, setError] = useState(""); useEffect(() => { @@ -151,7 +151,7 @@ function Signin({ providers }: SigninProps) { Signin.getLayout = (page) => (
-
+
{page}
@@ -209,11 +209,12 @@ const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSaf ); }; -export const getServerSideProps: GetServerSideProps = async () => { +export const getServerSideProps: GetServerSideProps = async ({ locale }) => { const providers = await getProviders(); return { props: { providers, + ...(await serverSideTranslations(locale, ["common"])), }, }; }; diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx index cceeaf4e..1c83eb23 100644 --- a/website/src/pages/create/assistant_reply.tsx +++ b/website/src/pages/create/assistant_reply.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const AssistantReply = () => { const { tasks, isLoading, reset, trigger } = useCreateAssistantReply(); diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index 6a51ca25..639df68f 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const InitialPrompt = () => { const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt(); diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index 8d2981e5..5898439c 100644 --- a/website/src/pages/create/user_reply.tsx +++ 
b/website/src/pages/create/user_reply.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const UserReply = () => { const { tasks, isLoading, reset, trigger } = useCreatePrompterReply(); diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index 78a47fd4..e0b8bba4 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -1,10 +1,20 @@ import { Flex } from "@chakra-ui/react"; import Head from "next/head"; +import { useMemo } from "react"; import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard"; import { getDashboardLayout } from "src/components/Layout"; import { TaskCategory } from "src/components/Tasks/TaskTypes"; +import { get } from "src/lib/api"; +import type { AvailableTasks, TaskType } from "src/types/Task"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; +import useSWRImmutable from "swr/immutable"; const Dashboard = () => { + const { data } = useSWRImmutable("/api/available_tasks", get); + + // TODO: show only these tasks: + const availableTasks = useMemo(() => filterAvailableTasks(data ?? 
{}), [data]); + return ( <> @@ -13,13 +23,19 @@ const Dashboard = () => { - + ); }; -Dashboard.getLayout = (page) => getDashboardLayout(page); +Dashboard.getLayout = getDashboardLayout; export default Dashboard; + +const filterAvailableTasks = (availableTasks: Partial) => + Object.entries(availableTasks) + .filter(([_, count]) => count > 0) + .sort((a, b) => b[1] - a[1]) + .map(([taskType]) => taskType) as TaskType[]; diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx index 695fbfdc..da79d92f 100644 --- a/website/src/pages/evaluate/rank_assistant_replies.tsx +++ b/website/src/pages/evaluate/rank_assistant_replies.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const RankAssistantReplies = () => { const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask(); diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx index 4eaaa110..f23fc0ed 100644 --- a/website/src/pages/evaluate/rank_initial_prompts.tsx +++ b/website/src/pages/evaluate/rank_initial_prompts.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const RankInitialPrompts = () => { const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask(); diff --git a/website/src/pages/evaluate/rank_user_replies.tsx 
b/website/src/pages/evaluate/rank_user_replies.tsx index dd23030e..cee82b87 100644 --- a/website/src/pages/evaluate/rank_user_replies.tsx +++ b/website/src/pages/evaluate/rank_user_replies.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const RankUserReplies = () => { const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask(); diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx index 5cb45278..07a6cb1c 100644 --- a/website/src/pages/label/label_assistant_reply.tsx +++ b/website/src/pages/label/label_assistant_reply.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const LabelAssistantReply = () => { const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask(); diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx index d7c1d4b2..8735044f 100644 --- a/website/src/pages/label/label_initial_prompt.tsx +++ b/website/src/pages/label/label_initial_prompt.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask"; +export { getDefaultStaticProps as getStaticProps } from 
"src/lib/default_static_props"; const LabelInitialPrompt = () => { const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask(); diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx index b48e6aab..17164e11 100644 --- a/website/src/pages/label/label_prompter_reply.tsx +++ b/website/src/pages/label/label_prompter_reply.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const LabelPrompterReply = () => { const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask(); diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index e53b0c52..f79dac52 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -2,6 +2,7 @@ import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-u import Head from "next/head"; import { getDashboardLayout } from "src/components/Layout"; import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; import { LeaderboardTimeFrame } from "src/types/Leaderboard"; const Leaderboard = () => { diff --git a/website/src/pages/messages/[id]/index.tsx b/website/src/pages/messages/[id]/index.tsx index f55c03cc..51c28c42 100644 --- a/website/src/pages/messages/[id]/index.tsx +++ b/website/src/pages/messages/[id]/index.tsx @@ -1,5 +1,6 @@ import { Box, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; +import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { 
MessageLoading } from "src/components/Loading/MessageLoading"; import { MessageTableEntry } from "src/components/Messages/MessageTableEntry"; @@ -48,10 +49,13 @@ const MessageDetail = ({ id }: { id: string }) => { ); }; -MessageDetail.getInitialProps = async ({ query }) => { - const { id } = query; - return { id }; -}; - MessageDetail.getLayout = (page) => getDashboardLayout(page); + +export const getServerSideProps = async ({ locale, query }) => ({ + props: { + id: query.id, + ...(await serverSideTranslations(locale, ["common"])), + }, +}); + export default MessageDetail; diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 627a8b18..3b6e342e 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout"; import { MessageTable } from "src/components/Messages/MessageTable"; import { get } from "src/lib/api"; import useSWRImmutable from "swr/immutable"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const MessagesDashboard = () => { const boxBgColor = useColorModeValue("white", "gray.800"); diff --git a/website/src/pages/privacy-policy.tsx b/website/src/pages/privacy-policy.tsx index 1c94b669..f84dc1e8 100644 --- a/website/src/pages/privacy-policy.tsx +++ b/website/src/pages/privacy-policy.tsx @@ -3,6 +3,7 @@ import Head from "next/head"; import { getTransparentHeaderLayout } from "src/components/Layout"; import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard"; import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const PrivacyPolicy = () => { const backgroundColor = useColorModeValue("gray.100", "gray.800"); diff --git a/website/src/pages/tasks/random.tsx b/website/src/pages/tasks/random.tsx index d2e850f5..be1809c3 100644 --- 
a/website/src/pages/tasks/random.tsx +++ b/website/src/pages/tasks/random.tsx @@ -4,9 +4,10 @@ import { getDashboardLayout } from "src/components/Layout"; import { LoadingScreen } from "src/components/Loading/LoadingScreen"; import { Task } from "src/components/Tasks/Task"; import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI"; +import { TaskType } from "src/types/Task"; const RandomTask = () => { - const { tasks, isLoading, trigger, reset } = useGenericTaskAPI("random"); + const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random); if (isLoading) { return ; diff --git a/website/src/pages/terms-of-service.tsx b/website/src/pages/terms-of-service.tsx index b0e298ba..41269bdf 100644 --- a/website/src/pages/terms-of-service.tsx +++ b/website/src/pages/terms-of-service.tsx @@ -3,6 +3,7 @@ import Head from "next/head"; import { getTransparentHeaderLayout } from "src/components/Layout"; import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard"; import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard"; +export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props"; const TermsOfService = () => { const TermsData = [ diff --git a/website/src/types/Task.ts b/website/src/types/Task.ts index d58f892c..8e5ada44 100644 --- a/website/src/types/Task.ts +++ b/website/src/types/Task.ts @@ -10,6 +10,8 @@ export const enum TaskType { label_initial_prompt = "label_initial_prompt", label_prompter_reply = "label_prompter_reply", label_assistant_reply = "label_assistant_reply", + + random = "random", } // we need to reconsider how to handle task content types @@ -32,3 +34,5 @@ export interface TaskResponse { userId: string; task: Task; } + +export type AvailableTasks = { [taskType in TaskType]: number }; diff --git a/website/src/types/Users.ts b/website/src/types/Users.ts index eeb1903a..39d2a663 100644 --- a/website/src/types/Users.ts +++ b/website/src/types/Users.ts @@ -1,7 +1,4 
@@ -/** - * Reports the Backend's knowledge of a user. - */ -export interface BackendUser { +export interface BackendUserCore { /** * The user's unique ID according to the `auth_method`. */ @@ -18,7 +15,12 @@ export interface BackendUser { * - local */ auth_method: string; +} +/** + * Reports the Backend's knowledge of a user. + */ +export interface BackendUser extends BackendUserCore { /** * The backend's UUID for this user. */