diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 5d96aa90..a01e009a 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -95,8 +95,8 @@ def query_frontend_user_messages( def query_frontend_user_messages_cursor( auth_method: str, username: str, - lt: Optional[str] = None, - gt: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, only_roots: Optional[bool] = False, include_deleted: Optional[bool] = False, max_count: Optional[int] = Query(10, gt=0, le=1000), @@ -105,8 +105,8 @@ def query_frontend_user_messages_cursor( db: Session = Depends(deps.get_db), ): return get_messages_cursor( - lt=lt, - gt=gt, + before=before, + after=after, auth_method=auth_method, username=username, only_roots=only_roots, diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 1ef1e929..d3d5e1c3 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -50,8 +50,8 @@ def query_messages( @router.get("/cursor", response_model=protocol.MessagePage) def get_messages_cursor( - lt: Optional[str] = None, - gt: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, user_id: Optional[UUID] = None, auth_method: Optional[str] = None, username: Optional[str] = None, @@ -63,6 +63,8 @@ def get_messages_cursor( api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db), ): + assert max_count is not None + def split_cursor(x: str | None) -> tuple[datetime, UUID]: if not x: return None, None @@ -74,11 +76,21 @@ def get_messages_cursor( except ValueError: raise OasstError("Invalid cursor value", OasstErrorCode.INVALID_CURSOR_VALUE) - lte_created_date, lt_id = split_cursor(lt) - gte_created_date, gt_id = split_cursor(gt) + if desc: + gte_created_date, gt_id = split_cursor(before) + lte_created_date, lt_id = split_cursor(after) + query_desc = not (before is not None and not after) + else: + lte_created_date, lt_id = split_cursor(before) + gte_created_date, gt_id = split_cursor(after) + query_desc = before is not None and not after + + print(f"{desc=} {query_desc=} {gte_created_date=} {lte_created_date=}") + + qry_max_count = max_count + 1 if before is None or after is None else max_count pr = PromptRepository(db, api_client) - messages = pr.query_messages_ordered_by_created_date( + items = pr.query_messages_ordered_by_created_date( user_id=user_id, auth_method=auth_method, username=username, @@ -89,22 +101,30 @@ def get_messages_cursor( lt_id=lt_id, only_roots=only_roots, deleted=None if include_deleted else False, - desc=desc, - limit=max_count, + desc=query_desc, + limit=qry_max_count, ) - items = utils.prepare_message_list(messages) + num_rows = len(items) + if qry_max_count > max_count and num_rows == qry_max_count: + assert not (before and after) + items = items[:-1] + + if desc != query_desc: + items.reverse() + + items = utils.prepare_message_list(items) n, p = None, None if len(items) > 0: - if len(items) == max_count or gte_created_date: + if (num_rows > max_count and before) or after: p = str(items[0].id) + "$" + items[0].created_date.isoformat() - if len(items) == max_count or lte_created_date: + if num_rows > max_count or before: n = str(items[-1].id) + "$" + items[-1].created_date.isoformat() else: - if gte_created_date: - p = gte_created_date.isoformat() - if lte_created_date: - n = lte_created_date.isoformat() + if after: + p = lte_created_date.isoformat() if desc else gte_created_date.isoformat() + if before: + n = gte_created_date.isoformat() if desc else lte_created_date.isoformat() order = "desc" if desc else "asc" return protocol.MessagePage(prev=p, next=n, sort_key="created_date", order=order, items=items) diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index c7ff9f9c..e4683a76 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -78,8 +78,8 @@ def get_users_ordered_by_display_name( @router.get("/cursor", response_model=protocol.FrontEndUserPage) def get_users_cursor( - lt: Optional[str] = None, - gt: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, sort_key: Optional[str] = Query("username", max_length=32), max_count: Optional[int] = Query(100, gt=0, le=10000), api_client_id: Optional[UUID] = None, @@ -99,8 +99,8 @@ def get_users_cursor( return x, None items: list[protocol.FrontEndUser] - qry_max_count = max_count + 1 if lt is None or gt is None else max_count - desc = lt and not gt + qry_max_count = max_count + 1 if before is None or after is None else max_count + desc = before is not None and not after def get_next_prev(num_rows: int, lt: str | None, gt: str | None, key_fn: Callable[[protocol.FrontEndUser], str]): p, n = None, None @@ -119,7 +119,7 @@ def get_users_cursor( def remove_extra_item(items: list[protocol.FrontEndUser], lt: str | None, gt: str | None): num_rows = len(items) if qry_max_count > max_count and num_rows == qry_max_count: - assert not (lt and gt) + assert not (lt is not None and gt is not None) items = items[:-1] if desc: items.reverse() @@ -127,8 +127,8 @@ def get_users_cursor( n, p = None, None if sort_key == "username": - lte_username, lt_id = split_cursor(lt) - gte_username, gt_id = split_cursor(gt) + lte_username, lt_id = split_cursor(before) + gte_username, gt_id = split_cursor(after) items = get_users_ordered_by_username( api_client_id=api_client_id, gte_username=gte_username, @@ -146,8 +146,8 @@ def get_users_cursor( p, n = get_next_prev(num_rows, lte_username, gte_username, lambda x: x.id) elif sort_key == "display_name": - lte_display_name, lt_id = split_cursor(lt) - gte_display_name, gt_id = split_cursor(gt) + lte_display_name, lt_id = split_cursor(before) + gte_display_name, gt_id = split_cursor(after) items = get_users_ordered_by_display_name( api_client_id=api_client_id, gte_display_name=gte_display_name, @@ -247,8 +247,8 @@ def query_user_messages( @router.get("/{user_id}/messages/cursor", response_model=protocol.MessagePage) def query_user_messages_cursor( user_id: Optional[UUID], - lt: Optional[str] = None, - gt: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, only_roots: Optional[bool] = False, include_deleted: Optional[bool] = False, max_count: Optional[int] = Query(10, gt=0, le=1000), @@ -257,8 +257,8 @@ def query_user_messages_cursor( db: Session = Depends(deps.get_db), ): return get_messages_cursor( - lt=lt, - gt=gt, + before=before, + after=after, user_id=user_id, only_roots=only_roots, include_deleted=include_deleted, diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 2e1e4b30..558ec502 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -43,7 +43,7 @@ def get_one_dataset(conf, dataset_name): if dataset_name == "debate_sum": train, eval = train_val_dataset(train, val_split=0.2) else: - val_name = "validation" if dataset_name not in ["billsum"] else "test" + val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test" eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 719fa0d6..c96ed576 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -3,7 +3,6 @@ from typing import Optional, Union import numpy as np import torch -from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase @@ -23,15 +22,8 @@ class DialogueDataCollator: flatten_messages = [] label_masks = [] - for feature_one in features: - assert len(feature_one) % 2 == 0, "Number of messages must be even" - # TODO: we should push this to dataset __getitem__ - messages = [ - (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") - + x - + (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "") - for i, x in enumerate(feature_one) - ] + for messages in features: + messages = list(messages) # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation diff --git a/model/supervised_finetuning/custom_datasets/formatting.py b/model/supervised_finetuning/custom_datasets/formatting.py new file mode 100644 index 00000000..a6c1c0d8 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/formatting.py @@ -0,0 +1,5 @@ +QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} + + +def format_pair(pair): + return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1] diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 4a1d83a3..1c823934 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -2,6 +2,7 @@ import json import os from urllib.request import urlopen +from custom_datasets.formatting import format_pair from torch.utils.data import Dataset @@ -49,8 +50,7 @@ class PromptGeneratedDataset(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) class InstructionTuning(Dataset): @@ -101,5 +101,4 @@ class InstructionTuning(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 7d9c7f48..47b1c247 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -7,14 +7,13 @@ import re from urllib.request import urlopen import numpy as np +from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from datasets import load_dataset from torch.utils.data import Dataset # @agoryuno contributed this re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") -QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} - def index_squad_v2(example): if len(example["answers"]["text"]): @@ -78,7 +77,7 @@ class QADataset(Dataset): def __getitem__(self, idx): data = self.dataset[idx] - return self.index_fn(data) + return format_pair(self.index_fn(data)) class WebGPT(Dataset): @@ -111,7 +110,7 @@ class WebGPT(Dataset): def __getitem__(self, index): question = self.index2question[index] answer = self.questions[question] - return [question, answer] + return format_pair((question, answer)) class SODA(Dataset): @@ -121,14 +120,14 @@ class SODA(Dataset): def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] - prefix = "{}{}. {}{}".format( - QA_SPECIAL_TOKENS["StartPrefix"], - data["narrative"], - "your name {}".format(play_as), - QA_SPECIAL_TOKENS["EndPrefix"], - ) question, answer = "", "" prefix, postfix = "", "" + dialogue_bg = "{}{} {}{}".format( + QA_SPECIAL_TOKENS["StartPrefix"], + data["narrative"], + "your are {}".format(play_as), + QA_SPECIAL_TOKENS["EndPrefix"], + ) previous_chat = [] for idx, convo in enumerate(data["dialogue"]): @@ -138,14 +137,20 @@ class SODA(Dataset): else: answer = convo postfix = data["speakers"][idx] + if len(question) and len(answer) and prefix != postfix and postfix == play_as: history = "".join( - ["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat] + [ + "{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) + for p in previous_chat + ] ) if len(history): history += "" - pairs.append((prefix + history + question, answer)) + prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"] + pairs.append((dialogue_bg + history + prompt, answer)) previous_chat.append((question, answer)) + return pairs def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: @@ -166,8 +171,8 @@ class SODA(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + # special token added during preprocess + return self.pairs[index] class SODADialogue(Dataset): @@ -218,7 +223,7 @@ class SODADialogue(Dataset): return len(self.pairs) def __getitem__(self, index): - return self.pairs[index] + return format_pair(self.pairs[index]) class JokeExplaination(Dataset): @@ -253,8 +258,7 @@ class JokeExplaination(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) # https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 2a097fe7..85d21a27 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -3,6 +3,7 @@ """ import random +from custom_datasets.formatting import format_pair from datasets import load_dataset from torch.utils.data import Dataset @@ -54,11 +55,12 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): - def __init__(self, dataset, cache_dir, split): + def __init__(self, dataset, cache_dir, split, max_words=512): self.name = dataset self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default + self.max_words = max_words def __len__(self): return len(self.dataset) @@ -72,4 +74,5 @@ class SummarizationDataset(Dataset): else: prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) - return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary) + context = "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[: self.max_words]), prompt]) + return format_pair((context, summary)) diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py index 815ac722..640b8d8d 100644 --- a/model/supervised_finetuning/custom_datasets/toxic_conversation.py +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -4,12 +4,13 @@ """ import random +from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from datasets import load_dataset from torch.utils.data import Dataset class ProsocialDialogueExplaination(Dataset): - name = "prosocial_explain" + name = "explain_prosocial" TEMPLATE = [ # 0 : reply or sentence of interest, 1 : reason of caution ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"), @@ -36,7 +37,7 @@ class ProsocialDialogueExplaination(Dataset): return len(self.pairs) def __getitem__(self, idx): - return self.pairs[idx] + return format_pair(self.pairs[idx]) class ProsocialDialogue(Dataset): @@ -58,8 +59,9 @@ class ProsocialDialogue(Dataset): dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: + prompt = QA_SPECIAL_TOKENS["Question"] + row["context"] + QA_SPECIAL_TOKENS["Answer"] for answer in row["rots"]: - self.pairs.append((self.PREFIX + row["context"], answer)) + self.pairs.append((self.PREFIX + prompt, answer)) def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index 694d31ce..f9a71a8e 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -8,6 +8,7 @@ """ import random +from custom_datasets.formatting import format_pair from datasets import load_dataset from torch.utils.data import Dataset @@ -82,7 +83,7 @@ class TranslationPair(Dataset): return len(self.pairs) def __getitem__(self, index): - return self.pairs[index] + return format_pair(self.pairs[index]) class WMT2019(TranslationPair): @@ -99,6 +100,8 @@ class WMT2019(TranslationPair): else: # translating in reverse direction source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) + if len(self.pairs) > 100000: + break class DiveMT(TranslationPair): diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 3b59f289..8d5ad08f 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -7,8 +7,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS - others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"] - translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] + others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"] + translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") for dataset_name in translation + others + summarize_base + qa_base: @@ -31,7 +31,6 @@ def test_collate_fn(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] - trains, evals = [], [] for dataset_name in others + qa_base + summarize_base: print(dataset_name) @@ -41,10 +40,10 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - # print(batch.keys()) - # print(tokenizer.decode(batch['input_ids'][0])) - # print('-----') - # print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]])) + print(batch.keys()) + print(tokenizer.decode(batch["input_ids"][0])) + print("-----") + print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]])) assert batch["targets"].shape[1] <= 512 dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 7b6e03b6..f7a0ab15 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -25,6 +25,10 @@ def get_tokenizer(conf): tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"}) elif "codegen" in conf.model_name: tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"}) + elif "pythia" in conf.model_name: + tokenizer.add_special_tokens( + {"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"} + ) additional_special_tokens = ( [] diff --git a/website/package-lock.json b/website/package-lock.json index 0d38e98c..71177bf5 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -33,6 +33,7 @@ "focus-visible": "^5.2.0", "framer-motion": "^6.5.1", "install": "^0.13.0", + "lucide-react": "^0.105.0", "next": "13.0.6", "next-auth": "^4.18.6", "next-i18next": "^13.0.3", @@ -45,7 +46,6 @@ "react-feature-flags": "^1.0.0", "react-hook-form": "^7.42.1", "react-i18next": "^12.1.4", - "react-icons": "^4.7.1", "react-table": "^7.8.0", "sharp": "^0.31.3", "swr": "^2.0.0", @@ -26726,6 +26726,14 @@ "yallist": "^3.0.2" } }, + "node_modules/lucide-react": { + "version": "0.105.0", + "resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.105.0.tgz", + "integrity": "sha512-iHaIkd4Wq6aNIVrFMXt3If8E/+2lnJd4WlCyntoJNIzZ8nWhdSSHWpsw7XM4rlw2319LZ2t4WLdnM8Z0ECDTOQ==", + "peerDependencies": { + "react": "^16.5.1 || ^17.0.0 || ^18.0.0" + } + }, "node_modules/lz-string": { "version": "1.4.4", "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.4.4.tgz", @@ -32657,14 +32665,6 @@ } } }, - "node_modules/react-icons": { - "version": "4.7.1", - "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz", - "integrity": "sha512-yHd3oKGMgm7zxo3EA7H2n7vxSoiGmHk5t6Ou4bXsfcgWyhfDKMpyKfhHR6Bjnn63c+YXBLBPUql9H4wPJM6sXw==", - "peerDependencies": { - "react": "*" - } - }, "node_modules/react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", @@ -57914,6 +57914,12 @@ "yallist": "^3.0.2" } }, + "lucide-react": { + "version": "0.105.0", + "resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.105.0.tgz", + "integrity": "sha512-iHaIkd4Wq6aNIVrFMXt3If8E/+2lnJd4WlCyntoJNIzZ8nWhdSSHWpsw7XM4rlw2319LZ2t4WLdnM8Z0ECDTOQ==", + "requires": {} + }, "lz-string": { "version": "1.4.4", "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.4.4.tgz", @@ -62143,12 +62149,6 @@ "html-parse-stringify": "^3.0.1" } }, - "react-icons": { - "version": "4.7.1", - "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz", - "integrity": "sha512-yHd3oKGMgm7zxo3EA7H2n7vxSoiGmHk5t6Ou4bXsfcgWyhfDKMpyKfhHR6Bjnn63c+YXBLBPUql9H4wPJM6sXw==", - "requires": {} - }, "react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", diff --git a/website/package.json b/website/package.json index e4fb8319..f2499920 100644 --- a/website/package.json +++ b/website/package.json @@ -50,6 +50,7 @@ "focus-visible": "^5.2.0", "framer-motion": "^6.5.1", "install": "^0.13.0", + "lucide-react": "^0.105.0", "next": "13.0.6", "next-auth": "^4.18.6", "next-i18next": "^13.0.3", @@ -62,7 +63,6 @@ "react-feature-flags": "^1.0.0", "react-hook-form": "^7.42.1", "react-i18next": "^12.1.4", - "react-icons": "^4.7.1", "react-table": "^7.8.0", "sharp": "^0.31.3", "swr": "^2.0.0", diff --git a/website/src/components/CallToAction.tsx b/website/src/components/CallToAction.tsx index e374a471..3b132797 100644 --- a/website/src/components/CallToAction.tsx +++ b/website/src/components/CallToAction.tsx @@ -1,9 +1,10 @@ import { Box, Link, Text, useColorMode } from "@chakra-ui/react"; +import { Github } from "lucide-react"; import { useTranslation } from "next-i18next"; import { useId } from "react"; -import { FaDiscord, FaGithub } from "react-icons/fa"; import { Container } from "./Container"; +import { Discord } from "./Icons/Discord"; const CIRCLE_HEIGHT = 558; const CIRCLE_WIDTH = 558; @@ -70,7 +71,7 @@ export function CallToAction() { type="button" className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2" > - + {t("discord")} @@ -81,7 +82,7 @@ export function CallToAction() { type="button" className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2" > - + {t("github")} diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable.tsx index f9ef4e49..466393eb 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable.tsx @@ -25,8 +25,8 @@ import { useDisclosure, } from "@chakra-ui/react"; import { ColumnDef, flexRender, getCoreRowModel, useReactTable } from "@tanstack/react-table"; +import { Filter } from "lucide-react"; import { ChangeEvent, ReactNode } from "react"; -import { FaFilter } from "react-icons/fa"; import { useDebouncedCallback } from "use-debounce"; export type DataTableColumnDef = ColumnDef & { @@ -148,7 +148,7 @@ const FilterModal = ({ diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx index 51e51a00..a9f29bc2 100644 --- a/website/src/components/EmptyState.tsx +++ b/website/src/components/EmptyState.tsx @@ -1,11 +1,10 @@ import { Box, Text, useColorModeValue } from "@chakra-ui/react"; +import { AlertTriangle, LucideIcon } from "lucide-react"; import NextLink from "next/link"; -import { FiAlertTriangle } from "react-icons/fi"; -import { IconType } from "react-icons/lib"; type EmptyStateProps = { text: string; - icon: IconType; + icon: LucideIcon; }; export const EmptyState = (props: EmptyStateProps) => { @@ -25,5 +24,5 @@ export const EmptyState = (props: EmptyStateProps) => { }; export const TaskEmptyState = () => { - return ; + return ; }; diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 7e28f2c2..58c7559f 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -22,8 +22,8 @@ import { } from "@chakra-ui/react"; import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid"; import clsx from "clsx"; +import { AlertCircle } from "lucide-react"; import { useEffect, useReducer } from "react"; -import { FiAlertCircle } from "react-icons/fi"; import { get, post } from "src/lib/api"; import { colors } from "src/styles/Theme/colors"; import { Message } from "src/types/Conversation"; @@ -154,7 +154,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => { - diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx index 64614578..0d70a442 100644 --- a/website/src/components/Header/Header.tsx +++ b/website/src/components/Header/Header.tsx @@ -1,10 +1,10 @@ import { Box, Button, Flex, Text } from "@chakra-ui/react"; +import { User } from "lucide-react"; import Image from "next/image"; import Link from "next/link"; import { useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import { Flags } from "react-feature-flags"; -import { FaUser } from "react-icons/fa"; import { LanguageSelector } from "src/components/LanguageSelector"; import { UserMenu } from "./UserMenu"; @@ -17,7 +17,7 @@ function AccountButton() { return ( - diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 912b75d4..8b5de035 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -11,11 +11,11 @@ import { Text, useColorModeValue, } from "@chakra-ui/react"; +import { AlertTriangle, Layout, LogOut, Settings, Shield } from "lucide-react"; import NextLink from "next/link"; import { signOut, useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import React, { ElementType, useCallback } from "react"; -import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi"; interface MenuOption { name: string; @@ -39,19 +39,19 @@ export function UserMenu() { { name: t("dashboard"), href: "/dashboard", - icon: FiLayout, + icon: Layout, isExternal: false, }, { name: t("account_settings"), href: "/account", - icon: FiSettings, + icon: Settings, isExternal: false, }, { name: t("report_a_bug"), href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose", - icon: FiAlertTriangle, + icon: AlertTriangle, isExternal: true, }, ]; @@ -60,7 +60,7 @@ export function UserMenu() { options.unshift({ name: t("admin_dashboard"), href: "/admin", - icon: FiShield, + icon: Shield, isExternal: false, }); } @@ -93,7 +93,7 @@ export function UserMenu() { _hover={{ textDecoration: "none" }} > - @@ -101,7 +101,7 @@ export function UserMenu() { - diff --git a/website/src/components/Icons/Discord.tsx b/website/src/components/Icons/Discord.tsx new file mode 100644 index 00000000..ea1118fb --- /dev/null +++ b/website/src/components/Icons/Discord.tsx @@ -0,0 +1,16 @@ +import { LucideIcon } from "lucide-react"; + +export const Discord: LucideIcon = ({ size = 24, ...rest }) => { + return ( + + + + ); +}; diff --git a/website/src/components/LanguageSelector/LanguageSelector.tsx b/website/src/components/LanguageSelector/LanguageSelector.tsx index 37659265..aea25dcd 100644 --- a/website/src/components/LanguageSelector/LanguageSelector.tsx +++ b/website/src/components/LanguageSelector/LanguageSelector.tsx @@ -8,12 +8,13 @@ const LanguageSelector = () => { const router = useRouter(); const { i18n } = useTranslation(); - const { language: currentLanguage } = i18n; - const languageNames = useMemo(() => { - return new Intl.DisplayNames([currentLanguage], { - type: "language", - }); - }, [currentLanguage]); + // Memo the set of locales and their display names. + const localesAndNames = useMemo(() => { + return router.locales.map((locale) => ({ + locale, + name: new Intl.DisplayNames([locale], { type: "language" }).of(locale), + })); + }, [router.locales]); const languageChanged = useCallback( async (option) => { @@ -25,12 +26,12 @@ const LanguageSelector = () => { [router] ); - const locales = router.locales; + const { language: currentLanguage } = i18n; return ( diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index 55085550..1b5bf430 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -1,8 +1,8 @@ // https://nextjs.org/docs/basic-features/layouts import { Box, Grid } from "@chakra-ui/react"; +import { Activity, BarChart2, Layout, MessageSquare, Users } from "lucide-react"; import type { NextPage } from "next"; -import { FiBarChart2, FiLayout, FiMessageSquare, FiUsers, FiActivity } from "react-icons/fi"; import { Header } from "src/components/Header"; import { SlimFooter } from "./Dashboard/SlimFooter"; @@ -38,19 +38,19 @@ export const getDashboardLayout = (page: React.ReactElement) => ( label: "Dashboard", pathname: "/dashboard", desc: "Dashboard Home", - icon: FiLayout, + icon: Layout, }, { label: "Messages", pathname: "/messages", desc: "Messages Dashboard", - icon: FiMessageSquare, + icon: MessageSquare, }, { label: "Leaderboard", pathname: "/leaderboard", desc: "User Leaderboard", - icon: FiBarChart2, + icon: BarChart2, }, ]} > @@ -73,13 +73,13 @@ export const getAdminLayout = (page: React.ReactElement) => ( label: "Users", pathname: "/admin", desc: "Users Dashboard", - icon: FiUsers, + icon: Users, }, { label: "Status", pathname: "/admin/status", desc: "Status Dashboard", - icon: FiActivity, + icon: Activity, }, ]} > diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx index ed98752c..acf92e05 100644 --- a/website/src/components/Messages/MessageTable.tsx +++ b/website/src/components/Messages/MessageTable.tsx @@ -5,13 +5,19 @@ import { Message } from "src/types/Conversation"; interface MessageTableProps { messages: Message[]; enableLink?: boolean; + highlightLastMessage?: boolean; } -export function MessageTable({ messages, enableLink }: MessageTableProps) { +export function MessageTable({ messages, enableLink, highlightLastMessage }: MessageTableProps) { return ( - {messages.map((item) => ( - + {messages.map((item, idx) => ( + ))} ); diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 1205991e..77202c44 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -1,14 +1,15 @@ -import { Avatar, Box, HStack, LinkBox, useBreakpoint, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; +import { Avatar, Box, HStack, useBreakpointValue, useColorModeValue } from "@chakra-ui/react"; import { boolean } from "boolean"; -import Link from "next/link"; import { useRouter } from "next/router"; import { useCallback, useMemo } from "react"; import { FlaggableElement } from "src/components/FlaggableElement"; import { Message } from "src/types/Conversation"; +import { colors } from "styles/Theme/colors"; interface MessageTableEntryProps { item: Message; enabled?: boolean; + highlight?: boolean; } export function MessageTableEntry(props: MessageTableEntryProps) { @@ -37,6 +38,7 @@ export function MessageTableEntry(props: MessageTableEntryProps) { ), [borderColor, inlineAvatar, item.is_assistant] ); + const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight); return ( @@ -48,6 +50,8 @@ export function MessageTableEntry(props: MessageTableEntryProps) { p="4" borderRadius="md" bg={item.is_assistant ? backgroundColor : backgroundColor2} + outline={props.highlight && "2px solid black"} + outlineColor={highlightColor} onClick={props.enabled && goToMessage} _hover={props.enabled && { cursor: "pointer", opacity: 0.9 }} whiteSpace="pre-wrap" diff --git a/website/src/components/SideMenu.tsx b/website/src/components/SideMenu.tsx index 3722eaa8..10e83ce4 100644 --- a/website/src/components/SideMenu.tsx +++ b/website/src/components/SideMenu.tsx @@ -1,15 +1,14 @@ import { Box, Button, Text, Tooltip, useColorMode } from "@chakra-ui/react"; +import { LucideIcon, Sun } from "lucide-react"; import Link from "next/link"; import { useRouter } from "next/router"; -import { FiSun } from "react-icons/fi"; -import { IconType } from "react-icons/lib"; import { colors } from "styles/Theme/colors"; export interface MenuButtonOption { label: string; pathname: string; desc: string; - icon: IconType; + icon: LucideIcon; } export interface SideMenuProps { @@ -47,7 +46,7 @@ export function SideMenu(props: SideMenuProps) { bg={router.pathname === item.pathname ? "blue.500" : null} _hover={router.pathname === item.pathname ? { bg: "blue.600" } : null} > - +