diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index f2119e3b..eb244d3b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -947,6 +947,9 @@ class PromptRepository: if deleted is not None: qry = qry.filter(Message.deleted == deleted) + if lang is not None: + qry = qry.filter(Message.lang == lang) + if desc: qry = qry.order_by(Message.created_date.desc(), Message.id.desc()) else: @@ -955,9 +958,6 @@ class PromptRepository: if limit is not None: qry = qry.limit(limit) - if lang is not None: - qry = qry.filter(Message.lang == lang) - return self._add_user_emojis_all(qry) def update_children_counts(self, message_tree_id: UUID): diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index d5b10e01..387e91e4 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -1,62 +1,18 @@ # Train using supervised examples -Requirements +## Requirements -``` -wandb -evaluate -datasets -transformers -torch -``` +`pip install -r requirements.txt` -Start training reward model +Start training SFT model ```bash -python trainer.py --configs defaults galactica-125 +python trainer.py --configs defaults galactica-125m ``` -## Dataset - -For now we only support webgpt and summary dataset from OpenAI. Once -open-asisstant dataset are available it will be added here. - -## Model - -Normally you should be able to add new models in configs/config.yml - -``` -your-model-name: - learning_rate: 2e-6 - model_name: - weight_decay: 0.01 - max_length: 812 - warmup_steps: 600 - gradient_checkpointing: false - gradient_accumulation_steps: 5 - per_device_train_batch_size: 4 - per_device_eval_batch_size: 4 -``` - -``` -python trainer.py --configs defaults your-model-name -``` - -However, if the model of your choice doesn't have pad_token, eos_token, -sep_token, you have to update utils.py `get_tokenizer` to use the right token. - -## Deepspeed support - -You can edit the configs/zero_config.json and use any stage you wish. The -current config uses zero-stage 3. For more details on how to setup the config -checkout [this page](https://www.deepspeed.ai/tutorials/zero/) - -Once you are satisfy with your deepzero config, you can add --deepspeed flag at -the end to trigger deepspeed - -``` -python trainer.py --configs defaults your-model-name --deepspeed -``` +For `wandb`: update the `entity` argument in `trainer.py`'s call to `wandb.init` +to be your weights and biases username per +[docs](https://docs.wandb.ai/ref/python/init). ## Dataset choices @@ -80,6 +36,74 @@ Currently only these languages are supported via prompt translation: ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh ``` +## Dataset sub-sampling + +We can subsample the **training** data by passing either the `fraction` or +`size` argument in the `configs/config.yml` file. Don't forget the additional +colon ":" after the dataset name when doing this. + +Example: + +``` + datasets: + - webgpt: + fraction : 0.05 + - prompt_dialogue: + size : 500 + - adversarial_qa + - trivia_qa_nocontext +``` + +In this example, per epoch we will use: + +- A random 5% of `webgpt`; +- A random 500 examples from `prompt_dialogue`; +- All examples from datasets for which we don't specify the `fraction` or `size` + argument. + +In the above example, per epoch we'll use a different 5% from `webgpt` and a +different 500 examples from `prompt_dialogue`. + +This works with `torch.distributed`. + +## Model + +Normally you should be able to add new models in `configs/config.yml` + +``` +your-model-name: + learning_rate: 2e-6 + model_name: + weight_decay: 0.01 + max_length: 812 + warmup_steps: 600 + gradient_checkpointing: false + gradient_accumulation_steps: 5 + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 +``` + +``` +python trainer.py --configs defaults your-model-name +``` + +However, if the model of your choice doesn't have `pad_token`, `eos_token`, +`sep_token`, you have to update `get_tokenizer` in `utils.py` to use the right +token. + +## Deepspeed support + +You can edit the configs/zero_config.json and use any stage you wish. The +current config uses zero-stage 3. For more details on how to setup the config +checkout [this page](https://www.deepspeed.ai/tutorials/zero/). + +Once you are satisfy with your deepzero config, you can add --deepspeed flag at +the end to trigger deepspeed + +``` +python trainer.py --configs defaults your-model-name --deepspeed +``` + ## Results Experimental results in wandb @@ -87,7 +111,7 @@ Experimental results in wandb ## TODOS -- decide on a model +- Decide on a model - Merge utils etc with reward model - Casual Modelling for GPT-JT does not leverage the bidirectional mask for the prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index d70fad41..79e4751d 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -17,7 +17,7 @@ defaults: freeze_layer: datasets: - webgpt - - prompt_dialogue + # - prompt_dialogue - squad_v2 - adversarial_qa - trivia_qa_nocontext diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 5faa22e6..092b7743 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -219,7 +219,7 @@ class SODA(Dataset): return pairs - def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: + def __init__(self, cache_dir, input_max_length=1024) -> None: super().__init__() self.pairs = [] @@ -230,9 +230,6 @@ class SODA(Dataset): if len(prompt) < input_max_length: self.pairs.append((prompt, answer)) - if len(self.pairs) > max_sample_size: - break - def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index f9a71a8e..008de751 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -100,8 +100,6 @@ class WMT2019(TranslationPair): else: # translating in reverse direction source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) - if len(self.pairs) > 100000: - break class DiveMT(TranslationPair): diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index 8f8cc63c..95e5a472 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,7 +4,6 @@ datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 gdown -mpi4py==3.1.4 nltk==3.8.1 numpy>=1.22.4 py7zr @@ -12,3 +11,4 @@ PyYAML>=6.0 scikit_learn==1.2.0 torch>=1.11.0 transformers==4.25.1 +wandb diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 83034d95..fb5d4bee 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -9,7 +9,7 @@ from efficiency_utils import fuse_gelu from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls +from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls def compute_metrics(eval_pred, preprocess_fns, metrics): @@ -31,6 +31,7 @@ class SFTTrainer(Trainer): self, model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, + sampler: torch.utils.data.sampler.Sampler = None, loss_function: str = "CrossEntropyLoss", poly_eps: float = 1.0, **kwargs, @@ -39,6 +40,7 @@ class SFTTrainer(Trainer): # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct self.loss_fct = get_loss(loss_function, poly_eps) + self.sampler = sampler def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") @@ -89,6 +91,32 @@ class SFTTrainer(Trainer): return (loss, logits, labels) + def get_train_dataloader(self): + """Inject custom data sampling behaviour into training loop""" + if self.sampler is None: + torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=self.data_collator, + ) + else: + dataloader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + sampler=self.sampler, + collate_fn=self.data_collator, + ) + if torch.cuda.device_count() <= 1: + return dataloader + else: + # Not strictly necessary to use accelerate, currently just + # ensures batches are padded to be divisible by # devices + from accelerate import Accelerator + + accelerator = Accelerator() + return accelerator.prepare(dataloader) + def _strtobool(x): return bool(strtobool(x)) @@ -142,8 +170,8 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) + sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF if training_conf.quantization: @@ -181,7 +209,6 @@ if __name__ == "__main__": ) assert len(evals) > 0 - if not training_conf.deepspeed or training_conf.local_rank == 0: import wandb @@ -192,8 +219,9 @@ if __name__ == "__main__": ) trainer = SFTTrainer( - model, - args, + model=model, + args=args, + sampler=sampler, loss_function=training_conf.loss_fn, poly_eps=training_conf.poly_eps, train_dataset=train, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index c3b8264f..380bca8e 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,11 +1,8 @@ -# from functools import partial +import random from pathlib import Path -from typing import NamedTuple +from typing import List, NamedTuple import evaluate - -# import nltk -# import numpy as np import transformers import yaml from custom_datasets import get_one_dataset @@ -15,6 +12,79 @@ from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset +from torch.utils.data.sampler import Sampler + + +class PerDatasetSampler(Sampler): + """Sampler which returns a fixed number of samples per dataset, per epoch. + + Example: + + Dataset 1 has 10,000 examples and we want 200 per epoch + Dataset 2 has 500 examples and we want all 500 per epoch + + Epoch size will be 700 and every epoch we'll sample a different + 200 from dataset 1. + + Parameters + ---------- + dataset_sizes : List[int] + A list with the size of each dataset. + dataset_size_per_epoch : List[int] + How many examples to get from each dataset per epoch. + + Note: dataset_sizes & dataset_size_per_epoch must be in the same order. + Further the examples in the underlying torch.utils.data.Dataset + must per ordered as dataset_1, dataset_2, ..., dataset_n. This is fine + if we concatenate a bunch of datasets together + e.g. using torch.utils.data.ConcatDataset which is current behaviour. + """ + + def __init__(self, dataset_sizes: List[int], dataset_size_per_epoch: List[int]): + self.dataset_sizes = dataset_sizes + self.dataset_size_per_epoch = dataset_size_per_epoch + self.num_datasets = len(dataset_sizes) + + def __iter__(self): + epoch_idx = [] + n = 0 + for i in range(self.num_datasets): + sampled_idx = random.sample(range(n, self.dataset_sizes[i] + n), self.dataset_size_per_epoch[i]) + n += self.dataset_sizes[i] + epoch_idx.extend(sampled_idx) + random.shuffle(epoch_idx) + return iter(epoch_idx) + + def __len__(self): + return int(sum(self.dataset_size_per_epoch)) + + @classmethod + def build_sampler_from_config(cls, training_conf, datasets): + dataset_sizes = [len(x) for x in datasets] + fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes) + dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)] + return cls(dataset_sizes, dataset_size_per_epoch) + + +def get_dataset_fractions(conf, dataset_sizes): + """Calculate fraction of each dataset to use per epoch when subsampling""" + fractions = [] + for i, data_config in enumerate(conf): + dataset_name = get_dataset_name_from_data_config(data_config) + if isinstance(data_config, dict): + if "fraction" in data_config[dataset_name]: + if data_config[dataset_name]["fraction"] <= 0: + raise ValueError("Please specify fraction as a value between 0 < fraction <= 1") + fractions.append(min(1, data_config[dataset_name]["fraction"])) + elif "size" in data_config[dataset_name]: + if data_config[dataset_name]["size"] > dataset_sizes[i]: + raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}") + fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) + else: + raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.") + else: + fractions.append(1) + return fractions class SpecialTokens(NamedTuple): @@ -140,10 +210,17 @@ def get_model(conf, tokenizer): return model +def get_dataset_name_from_data_config(data_config): + if isinstance(data_config, dict): + return list(data_config.keys())[0] + return data_config + + def get_dataset(conf, tokenizer): train_datasets, evals = [], {} - for dataset_name in conf.datasets: + for data_config in conf.datasets: + dataset_name = get_dataset_name_from_data_config(data_config) train, val = get_one_dataset(conf, dataset_name) train_datasets.append(train) evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val diff --git a/website/public/locales/da/dashboard.json b/website/public/locales/da/dashboard.json new file mode 100644 index 00000000..79c790ef --- /dev/null +++ b/website/public/locales/da/dashboard.json @@ -0,0 +1,8 @@ +{ + "grab_a_task": "Tag en opgave!", + "create": "Lav", + "evaluate": "Evaluer", + "label": "Label", + "dashboard": "Dashboard", + "go": "Start" +} diff --git a/website/public/locales/en/message.json b/website/public/locales/en/message.json index a531f0a4..75565656 100644 --- a/website/public/locales/en/message.json +++ b/website/public/locales/en/message.json @@ -8,7 +8,7 @@ "open_new_tab_action": "Open in new tab", "parent": "Parent", "reactions": "Reactions", - "recent_messages": "Recent Messages", + "recent_messages": "Recent Messages in {{language}}", "report_action": "Report", "report_placeholder": "Why should this message be reviewed?", "report_title": "Report", diff --git a/website/public/locales/vi/common.json b/website/public/locales/vi/common.json index 71309549..aa2cc855 100644 --- a/website/public/locales/vi/common.json +++ b/website/public/locales/vi/common.json @@ -3,21 +3,17 @@ "account_settings": "Tài khoản", "admin_dashboard": "Trang cho admin", "connect": "Liên hệ", - "conversational": "Chatbot AI cho tất cả mọi người", - "copied": "Copied", + "conversational": "Chatbot AI cho mọi người", + "copied": "Đã sao chép", "dark_mode": "Giao diện tối", - "dashboard_home": "Trang chính", "dashboard": "Trang chính", "delete": "Xoá", "discord": "Discord", "docs": "Hướng dẫn", "github": "GitHub", - "leaderboard": "Bảng xếp hạng", "legal": "Luật lệ", "light_mode": "Giao diện sáng", "loading": "Đang tải...", - "messages_dashboard": "Bảng xếp hạng tin nhắn", - "messages": "Tin nhắn", "more_information": "Xem thêm", "no": "Không", "privacy_policy": "Chính sách bảo mật", @@ -26,11 +22,8 @@ "sign_out": "Đăng xuất", "status_dashboard": "Trang hiện tình trạng", "status": "Tình trạng", - "success": "Success", + "success": "Thành công", "terms_of_service": "Điều khoản sử dụng", "title": "Open Assistant", - "user_leaderboard": "Bảng xếp hạng người dùng", - "users_dashboard": "Trang về người dùng", - "users": "Người dùng", "yes": "Có" } diff --git a/website/public/locales/vi/index.json b/website/public/locales/vi/index.json index 43022145..1c6b5e19 100644 --- a/website/public/locales/vi/index.json +++ b/website/public/locales/vi/index.json @@ -1,15 +1,15 @@ { "blurb": "Đây sẽ là cuộc cách mạng công nghệ mới.", - "blurb1": "Giống như cách Stable Diffusion đã cho mọi người công cụ để làm tranh ảnh bằng AI, Open Assistant sẽ làm như vậy với con chatbot mã nguồn mở mạnh nhất thế giới.", - "description": "Chatbot trí tuệ nhân tạo mã nguồn mở, dựa trên mô hình ngôn ngữ lớn của LAION và các tình nguyện viên trên toàn thế giới.", + "blurb1": "Giống như Stable Diffusion với tranh ảnh bằng AI, Open Assistant sẽ làm tương tự như vậy với con chatbot mã nguồn mở mạnh nhất thế giới.", + "description": "Chatbot trí tuệ nhân tạo mã nguồn mở, dựa trên mô hình ngôn ngữ lớn của LAION và các tình nguyện.", "faq_items": { "q0": "Open Assistant bây giờ thế nào rồi?", "a0": "Dự án này đang trong giai đoạn phát triển, từ những nghiên cứu về sử dụng RLHF (học từ phản hồi con người) trong các mô hình ngôn ngữ lớn.", "q1": "Open Assistant được phát triển bởi ai?", - "a1": "Open Assistant là dự ản được phát triển bởi LAION and các tình nguyện viên trên toàn thế giới." + "a1": "Open Assistant là dự ản được phát triển bởi LAION and các tình nguyện viên." }, "faq_title": "Câu hỏi", "join_us_description": "Các dự án mã nguồn mở được phát triển bởi những người như bạn. Triết lý mã nguồn mở là hợp tác để tạo và phát triển công nghệ mới mà làm giàu thế giới quanh ta. Bạn có muốn tham gia không? Liên hệ chúng tôi ở đây:", "join_us_title": "Tham gia", - "subtitle": "Chatbot AI cho tất cả mọi người" + "subtitle": "Chatbot AI cho mọi người" } diff --git a/website/public/locales/vi/leaderboard.json b/website/public/locales/vi/leaderboard.json index ebe25dda..d5b4e167 100644 --- a/website/public/locales/vi/leaderboard.json +++ b/website/public/locales/vi/leaderboard.json @@ -1,15 +1,15 @@ { "daily": "Ngày", - "label": "Nhãn", + "label": "Số nhãn", "last_updated_at": "Cập nhật lần cuối: {{val, datetime}}", "leaderboard": "Bảng xếp hạng", "monthly": "Tháng", "next": "Tiếp", "overall": "Tổng quan", "previous": "Trước", - "prompt": "Câu đầu tiên", + "prompt": "Số câu đầu", "rank": "Xếp hạng", - "reply": "Câu trả lời", + "reply": "Số câu trả lời", "score": "Điểm", "top_5_contributors_today": "Top 5 người đóng góp", "user": "Tên người dùng", diff --git a/website/public/locales/vi/message.json b/website/public/locales/vi/message.json index baabe4b6..45d53060 100644 --- a/website/public/locales/vi/message.json +++ b/website/public/locales/vi/message.json @@ -1,20 +1,20 @@ { - "copy_message_id": "Copy message ID", + "copy_message_id": "Sao chép ID", "label_action": "Nhãn", "label_title": "Nhãn", "message": "Tin nhắn", - "message_deleted": "Message deleted", + "message_deleted": "Tin nhắn đã xoá", "open_new_tab_action": "Mở ở trang mới", "parent": "Tin nhắn gốc", "reactions": "Bình luận", "recent_messages": "Tin nhắn gần đây", "report_action": "Báo cáo", - "report_placeholder": "Tại sao tin nhắn này cần được báo cáo?", + "report_placeholder": "Nêu lý do để báo cáo", "report_title": "Báo cáo", "send_report": "Gửi", - "stop_tree": "Stop tree", + "stop_tree": "Dừng nhánh tin nhắn", "submit_labels": "Gửi", - "tree_stopped": "Tree stopped {{id}}", + "tree_stopped": "Nhánh tin nhắn {{id}} đã dừng", "view_user": "Xem người dùng", "your_recent_messages": "Tin nhắn gần đây của bạn" } diff --git a/website/public/locales/vi/side_menu.json b/website/public/locales/vi/side_menu.json new file mode 100644 index 00000000..6b06846f --- /dev/null +++ b/website/public/locales/vi/side_menu.json @@ -0,0 +1,12 @@ +{ + "dashboard": "Trang chính", + "dashboard_home": "Trang chính", + "leaderboard": "Bảng xếp hạng", + "messages": "Tin nhắn", + "messages_dashboard": "Bảng xếp hạng tin nhắn", + "status": "Tình trạng", + "status_dashboard": "Trang hiện tình trạng", + "user_leaderboard": "Bảng xếp hạng người dùng", + "users": "Người dùng", + "users_dashboard": "Trang về người dùng" +} diff --git a/website/public/locales/vi/tasks.json b/website/public/locales/vi/tasks.json index e197c410..7884f5fe 100644 --- a/website/public/locales/vi/tasks.json +++ b/website/public/locales/vi/tasks.json @@ -1,22 +1,22 @@ { "available_task_count": "{{count}} việc", "classify_assistant_reply": { - "label": "Phân loại tin nhắn của Open Assistant", - "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.", + "label": "Phân loại các tin nhắn của Open Assistant", + "desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.", "overview": "Từ cuộc trò truyện ở dưới, trả lời các câu hỏi về câu trả lời cuối trong cuộc trò truyện." }, "classify_initial_prompt": { - "label": "Phân loại tin nhắn đầu", - "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.", + "label": "Phân loại các tin nhắn đầu tiên", + "desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.", "overview": "Đọc tin nhắn đầu và trả lời các câu hỏi." }, "classify_prompter_reply": { "label": "Phân loại tin nhắn người dùng", - "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.", + "desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.", "overview": "Từ cuộc trò truyện ở dưới, trả lời các câu hỏi về câu trả lời cuối trong cuộc trò truyện." }, "create_initial_prompt": { - "label": "Tạo tin nhắn đầu", + "label": "Tạo tin nhắn đầu tiên", "desc": "Viết tin nhắn đầu tiên để làm bộ dữ liệu cho Open Assistant.", "overview": "Viết tin nhắn đầu tiên để Open Assistant trả lời", "instruction": "Viết tin nhắn đầu", @@ -27,17 +27,17 @@ "unchanged_message": "Are you sure you would like to continue?" }, "label_assistant_reply": { - "label": "Tạo nhãn cho tin nhắn của Open Assistant", + "label": "Đánh giá các tin nhắn của Open Assistant", "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn của Open Assistant.", "overview": "Từ cuộc trò truyện ở dưới, tạo nhãn dữ liệu cho tin nhắn sau." }, "label_initial_prompt": { - "label": "Tạo nhãn cho tin nhắn đầu", + "label": "Đánh giá các tin nhắn đầu tiên", "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn đầu.", "overview": "Tạo nhãn dữ liệu cho tin nhắn sau." }, "label_prompter_reply": { - "label": "Tạo nhãn cho tin nhắn người dùng", + "label": "Đánh giá các tin nhắn người dùng", "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn của người dùng.", "overview": "Từ cuộc trò truyện ở dưới, tạo nhãn dữ liệu cho tin nhắn sau." }, @@ -46,28 +46,28 @@ "desc": "Giúp cải thiện Open Assistant bằng cách làm một việc ngẫu nhiên." }, "rank_assistant_replies": { - "label": "Xếp hạng câu trả lời của Open Assistant", + "label": "Xếp hạng các câu trả lời của Open Assistant", "desc": "Đánh giá độ chính xác và dễ đọc của các câu trả lời mà Open Assistant đưa ra.", "overview": "Từ những câu trả lời của Open Assistant, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.", "unchanged_title": "Chưa thay đổi thứ tự", "unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?" }, "rank_initial_prompts": { - "label": "Xếp hạng tin nhắn đầu tiên", + "label": "Xếp hạng các tin nhắn đầu tiên", "desc": "Đánh giá độ chính xác và dễ đọc của các câu trả lời của tin nhắn đầu tiên.", "overview": "Từ những tin nhắn đầu sau, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.", "unchanged_title": "Chưa thay đổi thứ tự", "unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?" }, "rank_user_replies": { - "label": "Xếp hạng câu trả lời của người dùng", + "label": "Xếp hạng các câu trả lời của người dùng", "desc": "Giúp cải thiện câu trả lời của Open Assistant.", "overview": "Từ những câu trả lời của người dùng, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.", "unchanged_title": "Chưa thay đổi thứ tự", "unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?" }, "reply_as_assistant": { - "label": "Đóng vai Open Assistant", + "label": "Đóng vai trợ lý", "desc": "Giúp cải thiện câu trả lời của Open Assistant.", "overview": "Tạo câu trả lời phù hợp cho cuộc trò truyện dưới đây", "response_placeholder": "Viết vào đây..." diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable/DataTable.tsx similarity index 90% rename from website/src/components/DataTable.tsx rename to website/src/components/DataTable/DataTable.tsx index 35246bf7..3e75f32e 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable/DataTable.tsx @@ -23,13 +23,23 @@ import { Tr, useDisclosure, } from "@chakra-ui/react"; -import { Cell, ColumnDef, flexRender, getCoreRowModel, Row, useReactTable } from "@tanstack/react-table"; +import { + Cell, + ColumnDef, + ExpandedState, + flexRender, + getCoreRowModel, + getExpandedRowModel, + Row, + useReactTable, +} from "@tanstack/react-table"; import { Filter } from "lucide-react"; import { useTranslation } from "next-i18next"; -import { ChangeEvent, ReactNode } from "react"; +import { ChangeEvent, ReactNode, useState } from "react"; import { useDebouncedCallback } from "use-debounce"; -export type DataTableColumnDef = ColumnDef & { +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type DataTableColumnDef = ColumnDef & { filterable?: boolean; span?: number | ((cell: Cell) => number | undefined); }; @@ -54,6 +64,7 @@ export type DataTableProps = { disablePrevious?: boolean; disablePagination?: boolean; rowProps?: TableRowProps | DataTableRowPropsCallback; + getSubRows?: (row: T) => T[] | undefined; }; export const DataTable = ({ @@ -68,12 +79,21 @@ export const DataTable = ({ disablePrevious, disablePagination, rowProps, + getSubRows, }: DataTableProps) => { const { t } = useTranslation("leaderboard"); + const [expanded, setExpanded] = useState({}); + const { getHeaderGroups, getRowModel } = useReactTable({ data, columns, getCoreRowModel: getCoreRowModel(), + getExpandedRowModel: getExpandedRowModel(), + state: { + expanded, + }, + getSubRows, + onExpandedChange: setExpanded, }); const handleFilterChange = (value: FilterItem) => { diff --git a/website/src/components/DataTable/jsonExpandRowModel.tsx b/website/src/components/DataTable/jsonExpandRowModel.tsx new file mode 100644 index 00000000..84c70e20 --- /dev/null +++ b/website/src/components/DataTable/jsonExpandRowModel.tsx @@ -0,0 +1,60 @@ +import { Card, CardBody, Flex } from "@chakra-ui/react"; +import { Cell, CellContext } from "@tanstack/react-table"; +import { ChevronDown, ChevronRight } from "lucide-react"; + +type ExpandableRow = Omit & { + shouldExpand?: boolean; +}; + +export const createJsonExpandRowModel = () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const renderCell = ({ row, getValue }: CellContext, any>) => { + if (!row.original.shouldExpand) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { shouldExpand, ...res } = row.original; + return ( + + +
{JSON.stringify(res, null, 2)}
+
+
+ ); + } + + return ( + + {row.getCanExpand() ? ( + + ) : null}{" "} + {getValue()} + + ); + }; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const span = (cell: Cell, any>) => + cell.row.original.shouldExpand ? undefined : cell.row.getVisibleCells().length; + + const getSubRows = (row: ExpandableRow) => + row.shouldExpand + ? [ + { + ...row, + shouldExpand: false, + }, + ] + : undefined; + + const toExpandable = function (arr: T[] | undefined, val = true): ExpandableRow[] { + return !arr ? [] : arr.map((element) => ({ ...element, shouldExpand: val })); + }; + + return { renderCell, span, getSubRows, toExpandable }; +}; diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 8b5de035..cad2dbe8 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -69,7 +69,7 @@ export function UserMenu() { - + {session.user.name || "New User"} diff --git a/website/src/components/LeaderboardTable/LeaderboardTable.tsx b/website/src/components/LeaderboardTable/LeaderboardTable.tsx index d59fc902..fc3ae099 100644 --- a/website/src/components/LeaderboardTable/LeaderboardTable.tsx +++ b/website/src/components/LeaderboardTable/LeaderboardTable.tsx @@ -2,19 +2,20 @@ import { Box, CircularProgress, Flex, Link, useColorModeValue } from "@chakra-ui import { createColumnHelper } from "@tanstack/react-table"; import { MoreHorizontal } from "lucide-react"; import NextLink from "next/link"; -import { useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import React, { useMemo } from "react"; +import { useHasRole } from "src/hooks/auth/useHasRole"; import { LeaderboardEntity, LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; -import { DataTable, DataTableColumnDef } from "../DataTable"; +import { DataTable, DataTableColumnDef } from "../DataTable/DataTable"; +import { createJsonExpandRowModel } from "../DataTable/jsonExpandRowModel"; import { useBoardPagination } from "./useBoardPagination"; import { useBoardRowProps } from "./useBoardRowProps"; import { useFetchBoard } from "./useFetchBoard"; type WindowLeaderboardEntity = LeaderboardEntity & { isSpaceRow?: boolean }; const columnHelper = createColumnHelper(); - +const jsonExpandRowModel = createJsonExpandRowModel(); /** * Presents a grid of leaderboard entries with more detailed information. */ @@ -39,17 +40,24 @@ export const LeaderboardTable = ({ } = useFetchBoard( `/api/leaderboard?time_frame=${timeFrame}&limit=${limit}&includeUserStats=${!hideCurrentUserRanking}` ); - const { data: session } = useSession(); - const isAdmin = session?.user?.role === "admin"; + const isAdmin = useHasRole("admin"); + const columns: DataTableColumnDef[] = useMemo( () => [ { ...columnHelper.accessor("rank", { header: t("rank"), - cell: ({ row, getValue }) => (row.original.isSpaceRow ? : getValue()), + cell: (ctx) => + ctx.row.original.isSpaceRow ? ( + + ) : isAdmin ? ( + jsonExpandRowModel.renderCell(ctx) + ) : ( + ctx.getValue() + ), }), - span: (cell) => (cell.row.original.isSpaceRow ? 6 : undefined), + span: (cell) => (cell.row.original.isSpaceRow ? 6 : jsonExpandRowModel.span(cell)), }, columnHelper.accessor("display_name", { header: t("user"), @@ -82,17 +90,17 @@ export const LeaderboardTable = ({ data: paginatedData, end, ...pagnationProps - } = useBoardPagination({ rowPerPage, data: reply?.leaderboard, limit }); - const data: WindowLeaderboardEntity[] = useMemo(() => { - if (hideCurrentUserRanking || !reply?.user_stats_window) { + } = useBoardPagination({ rowPerPage, data: jsonExpandRowModel.toExpandable(reply?.leaderboard || []), limit }); + const data = useMemo(() => { + if (hideCurrentUserRanking || !reply?.user_stats_window || reply.user_stats_window.length === 0) { return paginatedData; } - const userStatsWindow: WindowLeaderboardEntity[] = reply.user_stats_window; + const userStatsWindow: WindowLeaderboardEntity[] = jsonExpandRowModel.toExpandable(reply.user_stats_window); const userStats = userStatsWindow.find((stats) => stats.highlighted); if (userStats && userStats.rank > end) { paginatedData.push( { isSpaceRow: true } as WindowLeaderboardEntity, - ...reply.user_stats_window.filter( + ...userStatsWindow.filter( (stats) => paginatedData.findIndex((leaderBoardEntity) => leaderBoardEntity.user_id === stats.user_id) === -1 ) // filter to avoid duplicated row ); @@ -116,6 +124,7 @@ export const LeaderboardTable = ({ columns={columns} caption={lastUpdated} rowProps={rowProps} + getSubRows={jsonExpandRowModel.getSubRows} {...pagnationProps} > ); diff --git a/website/src/components/LeaderboardTable/TrollboardTable.tsx b/website/src/components/LeaderboardTable/TrollboardTable.tsx index 1b1ea118..0c971655 100644 --- a/website/src/components/LeaderboardTable/TrollboardTable.tsx +++ b/website/src/components/LeaderboardTable/TrollboardTable.tsx @@ -1,26 +1,42 @@ -import { Box, CircularProgress, Flex, Link } from "@chakra-ui/react"; +import { Box, CircularProgress, Flex, IconButton, Link, Tooltip } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; -import { ThumbsDown, ThumbsUp } from "lucide-react"; +import { Mail, ThumbsDown, ThumbsUp, User } from "lucide-react"; import NextLink from "next/link"; import { FetchTrollBoardResponse, TrollboardEntity, TrollboardTimeFrame } from "src/types/Trollboard"; -import { DataTable } from "../DataTable"; +import { DataTable, DataTableColumnDef } from "../DataTable/DataTable"; +import { createJsonExpandRowModel } from "../DataTable/jsonExpandRowModel"; +import { Discord } from "../Icons/Discord"; import { useBoardPagination } from "./useBoardPagination"; import { useBoardRowProps } from "./useBoardRowProps"; import { useFetchBoard } from "./useFetchBoard"; + const columnHelper = createColumnHelper(); - const toPercentage = (num: number) => `${Math.round(num * 100)}%`; +const jsonExpandRowModel = createJsonExpandRowModel(); -const columns = [ - columnHelper.accessor("rank", {}), +const columns: DataTableColumnDef[] = [ + { + ...columnHelper.accessor("rank", { + cell: jsonExpandRowModel.renderCell, + }), + span: jsonExpandRowModel.span, + }, columnHelper.accessor("display_name", { header: "Display name", - cell: ({ getValue, row }) => ( - - {getValue()} - - ), + cell: ({ getValue, row }) => { + const isEmail = row.original.auth_method === "local"; + return ( + + + {getValue()} + + + {isEmail ? : } + + + ); + }, }), columnHelper.accessor("troll_score", { header: "Troll score", @@ -45,36 +61,19 @@ const columns = [ columnHelper.accessor((row) => row.spam + row.spam_prompts, { header: "Spam", }), - columnHelper.accessor("lang_mismach", { - header: "Lang mismach", - }), - columnHelper.accessor("not_appropriate", { - header: "Not appropriate", - }), - columnHelper.accessor("pii", {}), - columnHelper.accessor("hate_speech", { - header: "Hate speech", - }), - columnHelper.accessor("sexual_content", { - header: "Sexual Content", - }), - columnHelper.accessor("political_content", { - header: "Political Content", - }), - columnHelper.accessor("quality", { - cell: ({ getValue }) => toPercentage(getValue()), - }), - columnHelper.accessor("helpfulness", { - cell: ({ getValue }) => toPercentage(getValue()), - }), - columnHelper.accessor("humor", { - cell: ({ getValue }) => toPercentage(getValue()), - }), - columnHelper.accessor("violence", { - cell: ({ getValue }) => toPercentage(getValue()), - }), columnHelper.accessor("toxicity", { - cell: ({ getValue }) => toPercentage(getValue()), + cell: ({ getValue }) => toPercentage(getValue() || 0), + }), + columnHelper.accessor((row) => row.user_id, { + header: "Actions", + cell: ({ row }) => ( + } + > + ), }), ]; @@ -94,7 +93,11 @@ export const TrollboardTable = ({ lastUpdated, } = useFetchBoard(`/api/admin/trollboard?time_frame=${timeFrame}&limit=${limit}`); - const { data, ...paginationProps } = useBoardPagination({ rowPerPage, data: trollboardRes?.trollboard, limit }); + const { data, ...paginationProps } = useBoardPagination({ + rowPerPage, + data: jsonExpandRowModel.toExpandable(trollboardRes?.trollboard), + limit, + }); const rowProps = useBoardRowProps(); if (isLoading) { return ; @@ -112,11 +115,12 @@ export const TrollboardTable = ({ }, }} > - + diff --git a/website/src/components/LeaderboardTable/useBoardRowProps.ts b/website/src/components/LeaderboardTable/useBoardRowProps.ts index 32f3fc56..be0fe7de 100644 --- a/website/src/components/LeaderboardTable/useBoardRowProps.ts +++ b/website/src/components/LeaderboardTable/useBoardRowProps.ts @@ -2,7 +2,7 @@ import { useColorModeValue, useToken } from "@chakra-ui/react"; import { useCallback } from "react"; import { colors } from "src/styles/Theme/colors"; -import { DataTableRowPropsCallback } from "../DataTable"; +import { DataTableRowPropsCallback } from "../DataTable/DataTable"; export const useBoardRowProps = () => { const borderColor = useToken("colors", useColorModeValue(colors.light.active, colors.dark.active)); diff --git a/website/src/components/Messages/MessageEmojiButton.stories.tsx b/website/src/components/Messages/MessageEmojiButton.stories.tsx index b083c966..5d3e8be6 100644 --- a/website/src/components/Messages/MessageEmojiButton.stories.tsx +++ b/website/src/components/Messages/MessageEmojiButton.stories.tsx @@ -11,17 +11,16 @@ export default { const Template = ({ emoji, count, - checked, - showCount, + ...rest }: { emoji: string; count: number; checked?: boolean; - showCount: boolean; + userIsAuthor: boolean; + disabled?: boolean; + userReacted: boolean; }) => { - return ( - - ); + return ; }; export const Default = Template.bind({}); @@ -29,7 +28,9 @@ Default.args = { emoji: "+1", count: 7, checked: false, - showCount: true, + userIsAuthor: false, + disabled: false, + userReacted: true, }; export const BigNumber = Template.bind({}); diff --git a/website/src/components/Messages/MessageEmojiButton.tsx b/website/src/components/Messages/MessageEmojiButton.tsx index f140a789..e3acb3c0 100644 --- a/website/src/components/Messages/MessageEmojiButton.tsx +++ b/website/src/components/Messages/MessageEmojiButton.tsx @@ -1,4 +1,5 @@ import { Button } from "@chakra-ui/react"; +import { useHasRole } from "src/hooks/auth/useHasRole"; import { MessageEmoji } from "src/types/Conversation"; import { emojiIcons } from "src/types/Emoji"; @@ -6,12 +7,27 @@ interface MessageEmojiButtonProps { emoji: MessageEmoji; checked?: boolean; onClick: () => void; - showCount: boolean; + userIsAuthor: boolean; + disabled?: boolean; + userReacted: boolean; } -export const MessageEmojiButton = ({ emoji, checked, onClick, showCount }: MessageEmojiButtonProps) => { +export const MessageEmojiButton = ({ + emoji, + checked, + onClick, + userIsAuthor, + disabled, + userReacted, +}: MessageEmojiButtonProps) => { const EmojiIcon = emojiIcons.get(emoji.name); - if (!EmojiIcon) return <>; + const isAdmin = useHasRole("admin"); + + if (!EmojiIcon) return null; + + const isDisabled = !!(userIsAuthor ? true : disabled); + const showCount = (emoji.count > 0 && userReacted) || userIsAuthor || isAdmin; + return ( ); }; diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index a82cc50c..c469de70 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -116,7 +116,8 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE key={emoji} emoji={{ name: emoji, count }} checked={emojiState.user_emojis.includes(emoji)} - showCount={emojiState.user_emojis.filter((emoji) => emoji === "+1" || emoji === "-1").length > 0} + userReacted={emojiState.user_emojis.length > 0} + userIsAuthor={message.user_is_author} onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))} /> ); diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index ab05d065..5c51c585 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -7,7 +7,7 @@ import { get } from "src/lib/api"; import type { FetchUsersResponse, User } from "src/types/Users"; import useSWR from "swr"; -import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable"; +import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable/DataTable"; interface Pagination { /** diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 826fc195..2ec76c85 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -314,8 +314,8 @@ export class OasstApiClient { return this.get(`/api/v1/messages?${params}`); } - fetch_recent_messages() { - return this.get(`/api/v1/messages`); + fetch_recent_messages(lang: string) { + return this.get(`/api/v1/messages`, { lang }); } fetch_message_children(messageId: string) { diff --git a/website/src/pages/admin/trollboard.tsx b/website/src/pages/admin/trollboard.tsx index aaef80f3..47e69d6b 100644 --- a/website/src/pages/admin/trollboard.tsx +++ b/website/src/pages/admin/trollboard.tsx @@ -18,7 +18,7 @@ const Leaderboard = () => { - {t("leaderboard")} + Trollboard diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts index fbcaee3c..978ed3ff 100644 --- a/website/src/pages/api/messages/index.ts +++ b/website/src/pages/api/messages/index.ts @@ -1,9 +1,11 @@ import { withoutRole } from "src/lib/auth"; import { createApiClient } from "src/lib/oasst_client_factory"; +import { getUserLanguage } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const client = await createApiClient(token); - const messages = await client.fetch_recent_messages(); + const userLanguage = getUserLanguage(req); + const messages = await client.fetch_recent_messages(userLanguage); res.status(200).json(messages); }); diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 8d950a2b..4cae792a 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -1,6 +1,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; import { useTranslation } from "next-i18next"; +import { useCookies } from "react-cookie"; import { getDashboardLayout } from "src/components/Layout"; import { MessageTable } from "src/components/Messages/MessageTable"; import { get } from "src/lib/api"; @@ -15,6 +16,8 @@ const MessagesDashboard = () => { const { data: messages } = useSWRImmutable("/api/messages", get, { revalidateOnMount: true }); const { data: userMessages } = useSWRImmutable(`/api/messages/user`, get, { revalidateOnMount: true }); + const [cookies] = useCookies(["NEXT_LOCALE"]); + const currentLanguage = cookies["NEXT_LOCALE"] || "en"; return ( <> @@ -24,7 +27,9 @@ const MessagesDashboard = () => { - {t("recent_messages")} + {t("recent_messages", { + language: new Intl.DisplayNames([currentLanguage], { type: "language" }).of(currentLanguage), + })}