Merge branch 'main' of github.com:LAION-AI/Open-Assistant

This commit is contained in:
Yannic Kilcher
2023-02-09 15:31:53 +01:00
31 changed files with 457 additions and 187 deletions
+3 -3
View File
@@ -947,6 +947,9 @@ class PromptRepository:
if deleted is not None:
qry = qry.filter(Message.deleted == deleted)
if lang is not None:
qry = qry.filter(Message.lang == lang)
if desc:
qry = qry.order_by(Message.created_date.desc(), Message.id.desc())
else:
@@ -955,9 +958,6 @@ class PromptRepository:
if limit is not None:
qry = qry.limit(limit)
if lang is not None:
qry = qry.filter(Message.lang == lang)
return self._add_user_emojis_all(qry)
def update_children_counts(self, message_tree_id: UUID):
+76 -52
View File
@@ -1,62 +1,18 @@
# Train using supervised examples
Requirements
## Requirements
```
wandb
evaluate
datasets
transformers
torch
```
`pip install -r requirements.txt`
Start training reward model
Start training SFT model
```bash
python trainer.py --configs defaults galactica-125
python trainer.py --configs defaults galactica-125m
```
## Dataset
For now we only support webgpt and summary dataset from OpenAI. Once
open-asisstant dataset are available it will be added here.
## Model
Normally you should be able to add new models in configs/config.yml
```
your-model-name:
learning_rate: 2e-6
model_name: <huggingface model name>
weight_decay: 0.01
max_length: 812
warmup_steps: 600
gradient_checkpointing: false
gradient_accumulation_steps: 5
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
```
```
python trainer.py --configs defaults your-model-name
```
However, if the model of your choice doesn't have pad_token, eos_token,
sep_token, you have to update utils.py `get_tokenizer` to use the right token.
## Deepspeed support
You can edit the configs/zero_config.json and use any stage you wish. The
current config uses zero-stage 3. For more details on how to setup the config
checkout [this page](https://www.deepspeed.ai/tutorials/zero/)
Once you are satisfy with your deepzero config, you can add --deepspeed flag at
the end to trigger deepspeed
```
python trainer.py --configs defaults your-model-name --deepspeed
```
For `wandb`: update the `entity` argument in `trainer.py`'s call to `wandb.init`
to be your weights and biases username per
[docs](https://docs.wandb.ai/ref/python/init).
## Dataset choices
@@ -80,6 +36,74 @@ Currently only these languages are supported via prompt translation:
ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh
```
## Dataset sub-sampling
We can subsample the **training** data by passing either the `fraction` or
`size` argument in the `configs/config.yml` file. Don't forget the additional
colon ":" after the dataset name when doing this.
Example:
```
datasets:
- webgpt:
fraction : 0.05
- prompt_dialogue:
size : 500
- adversarial_qa
- trivia_qa_nocontext
```
In this example, per epoch we will use:
- A random 5% of `webgpt`;
- A random 500 examples from `prompt_dialogue`;
- All examples from datasets for which we don't specify the `fraction` or `size`
argument.
In the above example, per epoch we'll use a different 5% from `webgpt` and a
different 500 examples from `prompt_dialogue`.
This works with `torch.distributed`.
## Model
Normally you should be able to add new models in `configs/config.yml`
```
your-model-name:
learning_rate: 2e-6
model_name: <huggingface model name>
weight_decay: 0.01
max_length: 812
warmup_steps: 600
gradient_checkpointing: false
gradient_accumulation_steps: 5
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
```
```
python trainer.py --configs defaults your-model-name
```
However, if the model of your choice doesn't have `pad_token`, `eos_token`,
`sep_token`, you have to update `get_tokenizer` in `utils.py` to use the right
token.
## Deepspeed support
You can edit the configs/zero_config.json and use any stage you wish. The
current config uses zero-stage 3. For more details on how to setup the config
checkout [this page](https://www.deepspeed.ai/tutorials/zero/).
Once you are satisfy with your deepzero config, you can add --deepspeed flag at
the end to trigger deepspeed
```
python trainer.py --configs defaults your-model-name --deepspeed
```
## Results
Experimental results in wandb
@@ -87,7 +111,7 @@ Experimental results in wandb
## TODOS
- decide on a model
- Decide on a model
- Merge utils etc with reward model
- Casual Modelling for GPT-JT does not leverage the bidirectional mask for the
prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1)
@@ -17,7 +17,7 @@ defaults:
freeze_layer:
datasets:
- webgpt
- prompt_dialogue
# - prompt_dialogue
- squad_v2
- adversarial_qa
- trivia_qa_nocontext
@@ -219,7 +219,7 @@ class SODA(Dataset):
return pairs
def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None:
def __init__(self, cache_dir, input_max_length=1024) -> None:
super().__init__()
self.pairs = []
@@ -230,9 +230,6 @@ class SODA(Dataset):
if len(prompt) < input_max_length:
self.pairs.append((prompt, answer))
if len(self.pairs) > max_sample_size:
break
def __len__(self):
return len(self.pairs)
@@ -100,8 +100,6 @@ class WMT2019(TranslationPair):
else: # translating in reverse direction
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
self.pairs.append((source, row[src]))
if len(self.pairs) > 100000:
break
class DiveMT(TranslationPair):
+1 -1
View File
@@ -4,7 +4,6 @@ datasets==2.8.0
deepspeed==0.7.7
evaluate==0.4.0
gdown
mpi4py==3.1.4
nltk==3.8.1
numpy>=1.22.4
py7zr
@@ -12,3 +11,4 @@ PyYAML>=6.0
scikit_learn==1.2.0
torch>=1.11.0
transformers==4.25.1
wandb
+33 -5
View File
@@ -9,7 +9,7 @@ from efficiency_utils import fuse_gelu
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
def compute_metrics(eval_pred, preprocess_fns, metrics):
@@ -31,6 +31,7 @@ class SFTTrainer(Trainer):
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
sampler: torch.utils.data.sampler.Sampler = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
**kwargs,
@@ -39,6 +40,7 @@ class SFTTrainer(Trainer):
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(loss_function, poly_eps)
self.sampler = sampler
def compute_loss(self, model, inputs, return_outputs=False):
labels_mask = inputs.pop("label_masks")
@@ -89,6 +91,32 @@ class SFTTrainer(Trainer):
return (loss, logits, labels)
def get_train_dataloader(self):
"""Inject custom data sampling behaviour into training loop"""
if self.sampler is None:
torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True,
collate_fn=self.data_collator,
)
else:
dataloader = torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
sampler=self.sampler,
collate_fn=self.data_collator,
)
if torch.cuda.device_count() <= 1:
return dataloader
else:
# Not strictly necessary to use accelerate, currently just
# ensures batches are padded to be divisible by # devices
from accelerate import Accelerator
accelerator = Accelerator()
return accelerator.prepare(dataloader)
def _strtobool(x):
return bool(strtobool(x))
@@ -142,8 +170,8 @@ if __name__ == "__main__":
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
if training_conf.quantization:
@@ -181,7 +209,6 @@ if __name__ == "__main__":
)
assert len(evals) > 0
if not training_conf.deepspeed or training_conf.local_rank == 0:
import wandb
@@ -192,8 +219,9 @@ if __name__ == "__main__":
)
trainer = SFTTrainer(
model,
args,
model=model,
args=args,
sampler=sampler,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
+83 -6
View File
@@ -1,11 +1,8 @@
# from functools import partial
import random
from pathlib import Path
from typing import NamedTuple
from typing import List, NamedTuple
import evaluate
# import nltk
# import numpy as np
import transformers
import yaml
from custom_datasets import get_one_dataset
@@ -15,6 +12,79 @@ from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
from torch.utils.data.sampler import Sampler
class PerDatasetSampler(Sampler):
"""Sampler which returns a fixed number of samples per dataset, per epoch.
Example:
Dataset 1 has 10,000 examples and we want 200 per epoch
Dataset 2 has 500 examples and we want all 500 per epoch
Epoch size will be 700 and every epoch we'll sample a different
200 from dataset 1.
Parameters
----------
dataset_sizes : List[int]
A list with the size of each dataset.
dataset_size_per_epoch : List[int]
How many examples to get from each dataset per epoch.
Note: dataset_sizes & dataset_size_per_epoch must be in the same order.
Further the examples in the underlying torch.utils.data.Dataset
must per ordered as dataset_1, dataset_2, ..., dataset_n. This is fine
if we concatenate a bunch of datasets together
e.g. using torch.utils.data.ConcatDataset which is current behaviour.
"""
def __init__(self, dataset_sizes: List[int], dataset_size_per_epoch: List[int]):
self.dataset_sizes = dataset_sizes
self.dataset_size_per_epoch = dataset_size_per_epoch
self.num_datasets = len(dataset_sizes)
def __iter__(self):
epoch_idx = []
n = 0
for i in range(self.num_datasets):
sampled_idx = random.sample(range(n, self.dataset_sizes[i] + n), self.dataset_size_per_epoch[i])
n += self.dataset_sizes[i]
epoch_idx.extend(sampled_idx)
random.shuffle(epoch_idx)
return iter(epoch_idx)
def __len__(self):
return int(sum(self.dataset_size_per_epoch))
@classmethod
def build_sampler_from_config(cls, training_conf, datasets):
dataset_sizes = [len(x) for x in datasets]
fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes)
dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)]
return cls(dataset_sizes, dataset_size_per_epoch)
def get_dataset_fractions(conf, dataset_sizes):
"""Calculate fraction of each dataset to use per epoch when subsampling"""
fractions = []
for i, data_config in enumerate(conf):
dataset_name = get_dataset_name_from_data_config(data_config)
if isinstance(data_config, dict):
if "fraction" in data_config[dataset_name]:
if data_config[dataset_name]["fraction"] <= 0:
raise ValueError("Please specify fraction as a value between 0 < fraction <= 1")
fractions.append(min(1, data_config[dataset_name]["fraction"]))
elif "size" in data_config[dataset_name]:
if data_config[dataset_name]["size"] > dataset_sizes[i]:
raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i])
else:
raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
else:
fractions.append(1)
return fractions
class SpecialTokens(NamedTuple):
@@ -140,10 +210,17 @@ def get_model(conf, tokenizer):
return model
def get_dataset_name_from_data_config(data_config):
if isinstance(data_config, dict):
return list(data_config.keys())[0]
return data_config
def get_dataset(conf, tokenizer):
train_datasets, evals = [], {}
for dataset_name in conf.datasets:
for data_config in conf.datasets:
dataset_name = get_dataset_name_from_data_config(data_config)
train, val = get_one_dataset(conf, dataset_name)
train_datasets.append(train)
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
+8
View File
@@ -0,0 +1,8 @@
{
"grab_a_task": "Tag en opgave!",
"create": "Lav",
"evaluate": "Evaluer",
"label": "Label",
"dashboard": "Dashboard",
"go": "Start"
}
+1 -1
View File
@@ -8,7 +8,7 @@
"open_new_tab_action": "Open in new tab",
"parent": "Parent",
"reactions": "Reactions",
"recent_messages": "Recent Messages",
"recent_messages": "Recent Messages in {{language}}",
"report_action": "Report",
"report_placeholder": "Why should this message be reviewed?",
"report_title": "Report",
+3 -10
View File
@@ -3,21 +3,17 @@
"account_settings": "Tài khoản",
"admin_dashboard": "Trang cho admin",
"connect": "Liên hệ",
"conversational": "Chatbot AI cho tất cả mọi người",
"copied": "Copied",
"conversational": "Chatbot AI cho mọi người",
"copied": "Đã sao chép",
"dark_mode": "Giao diện tối",
"dashboard_home": "Trang chính",
"dashboard": "Trang chính",
"delete": "Xoá",
"discord": "Discord",
"docs": "Hướng dẫn",
"github": "GitHub",
"leaderboard": "Bảng xếp hạng",
"legal": "Luật lệ",
"light_mode": "Giao diện sáng",
"loading": "Đang tải...",
"messages_dashboard": "Bảng xếp hạng tin nhắn",
"messages": "Tin nhắn",
"more_information": "Xem thêm",
"no": "Không",
"privacy_policy": "Chính sách bảo mật",
@@ -26,11 +22,8 @@
"sign_out": "Đăng xuất",
"status_dashboard": "Trang hiện tình trạng",
"status": "Tình trạng",
"success": "Success",
"success": "Thành công",
"terms_of_service": "Điều khoản sử dụng",
"title": "Open Assistant",
"user_leaderboard": "Bảng xếp hạng người dùng",
"users_dashboard": "Trang về người dùng",
"users": "Người dùng",
"yes": "Có"
}
+4 -4
View File
@@ -1,15 +1,15 @@
{
"blurb": "Đây sẽ là cuộc cách mạng công nghệ mới.",
"blurb1": "Giống như cách Stable Diffusion đã cho mọi người công cụ để làm tranh ảnh bằng AI, Open Assistant sẽ làm như vậy với con chatbot mã nguồn mở mạnh nhất thế giới.",
"description": "Chatbot trí tuệ nhân tạo mã nguồn mở, dựa trên mô hình ngôn ngữ lớn của LAION và các tình nguyện viên trên toàn thế giới.",
"blurb1": "Giống như Stable Diffusion với tranh ảnh bằng AI, Open Assistant sẽ làm tương tự như vậy với con chatbot mã nguồn mở mạnh nhất thế giới.",
"description": "Chatbot trí tuệ nhân tạo mã nguồn mở, dựa trên mô hình ngôn ngữ lớn của LAION và các tình nguyện.",
"faq_items": {
"q0": "Open Assistant bây giờ thế nào rồi?",
"a0": "Dự án này đang trong giai đoạn phát triển, từ những nghiên cứu về sử dụng RLHF (học từ phản hồi con người) trong các mô hình ngôn ngữ lớn.",
"q1": "Open Assistant được phát triển bởi ai?",
"a1": "Open Assistant là dự ản được phát triển bởi LAION and các tình nguyện viên trên toàn thế giới."
"a1": "Open Assistant là dự ản được phát triển bởi LAION and các tình nguyện viên."
},
"faq_title": "Câu hỏi",
"join_us_description": "Các dự án mã nguồn mở được phát triển bởi những người như bạn. Triết lý mã nguồn mở là hợp tác để tạo và phát triển công nghệ mới mà làm giàu thế giới quanh ta. Bạn có muốn tham gia không? Liên hệ chúng tôi ở đây:",
"join_us_title": "Tham gia",
"subtitle": "Chatbot AI cho tất cả mọi người"
"subtitle": "Chatbot AI cho mọi người"
}
+3 -3
View File
@@ -1,15 +1,15 @@
{
"daily": "Ngày",
"label": "Nhãn",
"label": "Số nhãn",
"last_updated_at": "Cập nhật lần cuối: {{val, datetime}}",
"leaderboard": "Bảng xếp hạng",
"monthly": "Tháng",
"next": "Tiếp",
"overall": "Tổng quan",
"previous": "Trước",
"prompt": "Câu đầu tiên",
"prompt": "Số câu đầu",
"rank": "Xếp hạng",
"reply": "Câu trả lời",
"reply": "Số câu trả lời",
"score": "Điểm",
"top_5_contributors_today": "Top 5 người đóng góp",
"user": "Tên người dùng",
+5 -5
View File
@@ -1,20 +1,20 @@
{
"copy_message_id": "Copy message ID",
"copy_message_id": "Sao chép ID",
"label_action": "Nhãn",
"label_title": "Nhãn",
"message": "Tin nhắn",
"message_deleted": "Message deleted",
"message_deleted": "Tin nhắn đã xoá",
"open_new_tab_action": "Mở ở trang mới",
"parent": "Tin nhắn gốc",
"reactions": "Bình luận",
"recent_messages": "Tin nhắn gần đây",
"report_action": "Báo cáo",
"report_placeholder": "Tại sao tin nhắn này cần được báo cáo?",
"report_placeholder": "Nêu lý do để báo cáo",
"report_title": "Báo cáo",
"send_report": "Gửi",
"stop_tree": "Stop tree",
"stop_tree": "Dừng nhánh tin nhắn",
"submit_labels": "Gửi",
"tree_stopped": "Tree stopped {{id}}",
"tree_stopped": "Nhánh tin nhắn {{id}} đã dừng",
"view_user": "Xem người dùng",
"your_recent_messages": "Tin nhắn gần đây của bạn"
}
+12
View File
@@ -0,0 +1,12 @@
{
"dashboard": "Trang chính",
"dashboard_home": "Trang chính",
"leaderboard": "Bảng xếp hạng",
"messages": "Tin nhắn",
"messages_dashboard": "Bảng xếp hạng tin nhắn",
"status": "Tình trạng",
"status_dashboard": "Trang hiện tình trạng",
"user_leaderboard": "Bảng xếp hạng người dùng",
"users": "Người dùng",
"users_dashboard": "Trang về người dùng"
}
+13 -13
View File
@@ -1,22 +1,22 @@
{
"available_task_count": "{{count}} việc",
"classify_assistant_reply": {
"label": "Phân loại tin nhắn của Open Assistant",
"desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.",
"label": "Phân loại các tin nhắn của Open Assistant",
"desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.",
"overview": "Từ cuộc trò truyện ở dưới, trả lời các câu hỏi về câu trả lời cuối trong cuộc trò truyện."
},
"classify_initial_prompt": {
"label": "Phân loại tin nhắn đầu",
"desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.",
"label": "Phân loại các tin nhắn đầu tiên",
"desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.",
"overview": "Đọc tin nhắn đầu và trả lời các câu hỏi."
},
"classify_prompter_reply": {
"label": "Phân loại tin nhắn người dùng",
"desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.",
"desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.",
"overview": "Từ cuộc trò truyện ở dưới, trả lời các câu hỏi về câu trả lời cuối trong cuộc trò truyện."
},
"create_initial_prompt": {
"label": "Tạo tin nhắn đầu",
"label": "Tạo tin nhắn đầu tiên",
"desc": "Viết tin nhắn đầu tiên để làm bộ dữ liệu cho Open Assistant.",
"overview": "Viết tin nhắn đầu tiên để Open Assistant trả lời",
"instruction": "Viết tin nhắn đầu",
@@ -27,17 +27,17 @@
"unchanged_message": "Are you sure you would like to continue?"
},
"label_assistant_reply": {
"label": "Tạo nhãn cho tin nhắn của Open Assistant",
"label": "Đánh giá các tin nhắn của Open Assistant",
"desc": "Tạo nhãn dữ liệu đánh giá tin nhắn của Open Assistant.",
"overview": "Từ cuộc trò truyện ở dưới, tạo nhãn dữ liệu cho tin nhắn sau."
},
"label_initial_prompt": {
"label": "Tạo nhãn cho tin nhắn đầu",
"label": "Đánh giá các tin nhắn đầu tiên",
"desc": "Tạo nhãn dữ liệu đánh giá tin nhắn đầu.",
"overview": "Tạo nhãn dữ liệu cho tin nhắn sau."
},
"label_prompter_reply": {
"label": "Tạo nhãn cho tin nhắn người dùng",
"label": "Đánh giá các tin nhắn người dùng",
"desc": "Tạo nhãn dữ liệu đánh giá tin nhắn của người dùng.",
"overview": "Từ cuộc trò truyện ở dưới, tạo nhãn dữ liệu cho tin nhắn sau."
},
@@ -46,28 +46,28 @@
"desc": "Giúp cải thiện Open Assistant bằng cách làm một việc ngẫu nhiên."
},
"rank_assistant_replies": {
"label": "Xếp hạng câu trả lời của Open Assistant",
"label": "Xếp hạng các câu trả lời của Open Assistant",
"desc": "Đánh giá độ chính xác và dễ đọc của các câu trả lời mà Open Assistant đưa ra.",
"overview": "Từ những câu trả lời của Open Assistant, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.",
"unchanged_title": "Chưa thay đổi thứ tự",
"unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?"
},
"rank_initial_prompts": {
"label": "Xếp hạng tin nhắn đầu tiên",
"label": "Xếp hạng các tin nhắn đầu tiên",
"desc": "Đánh giá độ chính xác và dễ đọc của các câu trả lời của tin nhắn đầu tiên.",
"overview": "Từ những tin nhắn đầu sau, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.",
"unchanged_title": "Chưa thay đổi thứ tự",
"unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?"
},
"rank_user_replies": {
"label": "Xếp hạng câu trả lời của người dùng",
"label": "Xếp hạng các câu trả lời của người dùng",
"desc": "Giúp cải thiện câu trả lời của Open Assistant.",
"overview": "Từ những câu trả lời của người dùng, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.",
"unchanged_title": "Chưa thay đổi thứ tự",
"unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?"
},
"reply_as_assistant": {
"label": "Đóng vai Open Assistant",
"label": "Đóng vai trợ lý",
"desc": "Giúp cải thiện câu trả lời của Open Assistant.",
"overview": "Tạo câu trả lời phù hợp cho cuộc trò truyện dưới đây",
"response_placeholder": "Viết vào đây..."
@@ -23,13 +23,23 @@ import {
Tr,
useDisclosure,
} from "@chakra-ui/react";
import { Cell, ColumnDef, flexRender, getCoreRowModel, Row, useReactTable } from "@tanstack/react-table";
import {
Cell,
ColumnDef,
ExpandedState,
flexRender,
getCoreRowModel,
getExpandedRowModel,
Row,
useReactTable,
} from "@tanstack/react-table";
import { Filter } from "lucide-react";
import { useTranslation } from "next-i18next";
import { ChangeEvent, ReactNode } from "react";
import { ChangeEvent, ReactNode, useState } from "react";
import { useDebouncedCallback } from "use-debounce";
export type DataTableColumnDef<T> = ColumnDef<T> & {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type DataTableColumnDef<T> = ColumnDef<T, any> & {
filterable?: boolean;
span?: number | ((cell: Cell<T, unknown>) => number | undefined);
};
@@ -54,6 +64,7 @@ export type DataTableProps<T> = {
disablePrevious?: boolean;
disablePagination?: boolean;
rowProps?: TableRowProps | DataTableRowPropsCallback<T>;
getSubRows?: (row: T) => T[] | undefined;
};
export const DataTable = <T,>({
@@ -68,12 +79,21 @@ export const DataTable = <T,>({
disablePrevious,
disablePagination,
rowProps,
getSubRows,
}: DataTableProps<T>) => {
const { t } = useTranslation("leaderboard");
const [expanded, setExpanded] = useState<ExpandedState>({});
const { getHeaderGroups, getRowModel } = useReactTable<T>({
data,
columns,
getCoreRowModel: getCoreRowModel(),
getExpandedRowModel: getExpandedRowModel(),
state: {
expanded,
},
getSubRows,
onExpandedChange: setExpanded,
});
const handleFilterChange = (value: FilterItem) => {
@@ -0,0 +1,60 @@
import { Card, CardBody, Flex } from "@chakra-ui/react";
import { Cell, CellContext } from "@tanstack/react-table";
import { ChevronDown, ChevronRight } from "lucide-react";
type ExpandableRow<T> = Omit<T, "shouldExpand"> & {
shouldExpand?: boolean;
};
export const createJsonExpandRowModel = <T,>() => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const renderCell = ({ row, getValue }: CellContext<ExpandableRow<T>, any>) => {
if (!row.original.shouldExpand) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { shouldExpand, ...res } = row.original;
return (
<Card variant="json">
<CardBody>
<pre>{JSON.stringify(res, null, 2)}</pre>
</CardBody>
</Card>
);
}
return (
<Flex alignItems="center">
{row.getCanExpand() ? (
<button
{...{
onClick: row.getToggleExpandedHandler(),
style: { cursor: "pointer" },
}}
>
{row.getIsExpanded() ? <ChevronDown /> : <ChevronRight />}
</button>
) : null}{" "}
{getValue()}
</Flex>
);
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const span = (cell: Cell<ExpandableRow<T>, any>) =>
cell.row.original.shouldExpand ? undefined : cell.row.getVisibleCells().length;
const getSubRows = (row: ExpandableRow<T>) =>
row.shouldExpand
? [
{
...row,
shouldExpand: false,
},
]
: undefined;
const toExpandable = function (arr: T[] | undefined, val = true): ExpandableRow<T>[] {
return !arr ? [] : arr.map((element) => ({ ...element, shouldExpand: val }));
};
return { renderCell, span, getSubRows, toExpandable };
};
+1 -1
View File
@@ -69,7 +69,7 @@ export function UserMenu() {
<Menu>
<MenuButton border="solid" borderRadius="full" borderWidth="thin" borderColor={borderColor}>
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
<Avatar size="sm" bgImage={session.user.image}></Avatar>
<Avatar size="sm" src={session.user.image}></Avatar>
<Text data-cy="username" className="hidden lg:flex">
{session.user.name || "New User"}
</Text>
@@ -2,19 +2,20 @@ import { Box, CircularProgress, Flex, Link, useColorModeValue } from "@chakra-ui
import { createColumnHelper } from "@tanstack/react-table";
import { MoreHorizontal } from "lucide-react";
import NextLink from "next/link";
import { useSession } from "next-auth/react";
import { useTranslation } from "next-i18next";
import React, { useMemo } from "react";
import { useHasRole } from "src/hooks/auth/useHasRole";
import { LeaderboardEntity, LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
import { DataTable, DataTableColumnDef } from "../DataTable";
import { DataTable, DataTableColumnDef } from "../DataTable/DataTable";
import { createJsonExpandRowModel } from "../DataTable/jsonExpandRowModel";
import { useBoardPagination } from "./useBoardPagination";
import { useBoardRowProps } from "./useBoardRowProps";
import { useFetchBoard } from "./useFetchBoard";
type WindowLeaderboardEntity = LeaderboardEntity & { isSpaceRow?: boolean };
const columnHelper = createColumnHelper<WindowLeaderboardEntity>();
const jsonExpandRowModel = createJsonExpandRowModel<WindowLeaderboardEntity>();
/**
* Presents a grid of leaderboard entries with more detailed information.
*/
@@ -39,17 +40,24 @@ export const LeaderboardTable = ({
} = useFetchBoard<LeaderboardReply & { user_stats_window?: LeaderboardReply["leaderboard"] }>(
`/api/leaderboard?time_frame=${timeFrame}&limit=${limit}&includeUserStats=${!hideCurrentUserRanking}`
);
const { data: session } = useSession();
const isAdmin = session?.user?.role === "admin";
const isAdmin = useHasRole("admin");
const columns: DataTableColumnDef<WindowLeaderboardEntity>[] = useMemo(
() => [
{
...columnHelper.accessor("rank", {
header: t("rank"),
cell: ({ row, getValue }) => (row.original.isSpaceRow ? <SpaceRow></SpaceRow> : getValue()),
cell: (ctx) =>
ctx.row.original.isSpaceRow ? (
<SpaceRow></SpaceRow>
) : isAdmin ? (
jsonExpandRowModel.renderCell(ctx)
) : (
ctx.getValue()
),
}),
span: (cell) => (cell.row.original.isSpaceRow ? 6 : undefined),
span: (cell) => (cell.row.original.isSpaceRow ? 6 : jsonExpandRowModel.span(cell)),
},
columnHelper.accessor("display_name", {
header: t("user"),
@@ -82,17 +90,17 @@ export const LeaderboardTable = ({
data: paginatedData,
end,
...pagnationProps
} = useBoardPagination({ rowPerPage, data: reply?.leaderboard, limit });
const data: WindowLeaderboardEntity[] = useMemo(() => {
if (hideCurrentUserRanking || !reply?.user_stats_window) {
} = useBoardPagination({ rowPerPage, data: jsonExpandRowModel.toExpandable(reply?.leaderboard || []), limit });
const data = useMemo(() => {
if (hideCurrentUserRanking || !reply?.user_stats_window || reply.user_stats_window.length === 0) {
return paginatedData;
}
const userStatsWindow: WindowLeaderboardEntity[] = reply.user_stats_window;
const userStatsWindow: WindowLeaderboardEntity[] = jsonExpandRowModel.toExpandable(reply.user_stats_window);
const userStats = userStatsWindow.find((stats) => stats.highlighted);
if (userStats && userStats.rank > end) {
paginatedData.push(
{ isSpaceRow: true } as WindowLeaderboardEntity,
...reply.user_stats_window.filter(
...userStatsWindow.filter(
(stats) => paginatedData.findIndex((leaderBoardEntity) => leaderBoardEntity.user_id === stats.user_id) === -1
) // filter to avoid duplicated row
);
@@ -116,6 +124,7 @@ export const LeaderboardTable = ({
columns={columns}
caption={lastUpdated}
rowProps={rowProps}
getSubRows={jsonExpandRowModel.getSubRows}
{...pagnationProps}
></DataTable>
);
@@ -1,26 +1,42 @@
import { Box, CircularProgress, Flex, Link } from "@chakra-ui/react";
import { Box, CircularProgress, Flex, IconButton, Link, Tooltip } from "@chakra-ui/react";
import { createColumnHelper } from "@tanstack/react-table";
import { ThumbsDown, ThumbsUp } from "lucide-react";
import { Mail, ThumbsDown, ThumbsUp, User } from "lucide-react";
import NextLink from "next/link";
import { FetchTrollBoardResponse, TrollboardEntity, TrollboardTimeFrame } from "src/types/Trollboard";
import { DataTable } from "../DataTable";
import { DataTable, DataTableColumnDef } from "../DataTable/DataTable";
import { createJsonExpandRowModel } from "../DataTable/jsonExpandRowModel";
import { Discord } from "../Icons/Discord";
import { useBoardPagination } from "./useBoardPagination";
import { useBoardRowProps } from "./useBoardRowProps";
import { useFetchBoard } from "./useFetchBoard";
const columnHelper = createColumnHelper<TrollboardEntity>();
const toPercentage = (num: number) => `${Math.round(num * 100)}%`;
const jsonExpandRowModel = createJsonExpandRowModel<TrollboardEntity>();
const columns = [
columnHelper.accessor("rank", {}),
const columns: DataTableColumnDef<TrollboardEntity>[] = [
{
...columnHelper.accessor("rank", {
cell: jsonExpandRowModel.renderCell,
}),
span: jsonExpandRowModel.span,
},
columnHelper.accessor("display_name", {
header: "Display name",
cell: ({ getValue, row }) => (
<Link as={NextLink} href={`/admin/manage_user/${row.original.user_id}`}>
{getValue()}
</Link>
),
cell: ({ getValue, row }) => {
const isEmail = row.original.auth_method === "local";
return (
<Flex gap="2" alignItems="center">
<Link as={NextLink} href={`/admin/manage_user/${row.original.user_id}`}>
{getValue()}
</Link>
<Tooltip label={`This user signin with ${isEmail ? "email" : "discord"}`}>
{isEmail ? <Mail size="20"></Mail> : <Discord size="20"></Discord>}
</Tooltip>
</Flex>
);
},
}),
columnHelper.accessor("troll_score", {
header: "Troll score",
@@ -45,36 +61,19 @@ const columns = [
columnHelper.accessor((row) => row.spam + row.spam_prompts, {
header: "Spam",
}),
columnHelper.accessor("lang_mismach", {
header: "Lang mismach",
}),
columnHelper.accessor("not_appropriate", {
header: "Not appropriate",
}),
columnHelper.accessor("pii", {}),
columnHelper.accessor("hate_speech", {
header: "Hate speech",
}),
columnHelper.accessor("sexual_content", {
header: "Sexual Content",
}),
columnHelper.accessor("political_content", {
header: "Political Content",
}),
columnHelper.accessor("quality", {
cell: ({ getValue }) => toPercentage(getValue()),
}),
columnHelper.accessor("helpfulness", {
cell: ({ getValue }) => toPercentage(getValue()),
}),
columnHelper.accessor("humor", {
cell: ({ getValue }) => toPercentage(getValue()),
}),
columnHelper.accessor("violence", {
cell: ({ getValue }) => toPercentage(getValue()),
}),
columnHelper.accessor("toxicity", {
cell: ({ getValue }) => toPercentage(getValue()),
cell: ({ getValue }) => toPercentage(getValue() || 0),
}),
columnHelper.accessor((row) => row.user_id, {
header: "Actions",
cell: ({ row }) => (
<IconButton
as={NextLink}
href={`/admin/manage_user/${row.original.user_id}`}
aria-label={"View user"}
icon={<User></User>}
></IconButton>
),
}),
];
@@ -94,7 +93,11 @@ export const TrollboardTable = ({
lastUpdated,
} = useFetchBoard<FetchTrollBoardResponse>(`/api/admin/trollboard?time_frame=${timeFrame}&limit=${limit}`);
const { data, ...paginationProps } = useBoardPagination({ rowPerPage, data: trollboardRes?.trollboard, limit });
const { data, ...paginationProps } = useBoardPagination<TrollboardEntity>({
rowPerPage,
data: jsonExpandRowModel.toExpandable(trollboardRes?.trollboard),
limit,
});
const rowProps = useBoardRowProps<TrollboardEntity>();
if (isLoading) {
return <CircularProgress isIndeterminate></CircularProgress>;
@@ -112,11 +115,12 @@ export const TrollboardTable = ({
},
}}
>
<DataTable<TrollboardEntity>
<DataTable
data={data}
columns={columns}
caption={lastUpdated}
rowProps={rowProps}
getSubRows={jsonExpandRowModel.getSubRows}
{...paginationProps}
></DataTable>
</Box>
@@ -2,7 +2,7 @@ import { useColorModeValue, useToken } from "@chakra-ui/react";
import { useCallback } from "react";
import { colors } from "src/styles/Theme/colors";
import { DataTableRowPropsCallback } from "../DataTable";
import { DataTableRowPropsCallback } from "../DataTable/DataTable";
export const useBoardRowProps = <T extends { highlighted: boolean }>() => {
const borderColor = useToken("colors", useColorModeValue(colors.light.active, colors.dark.active));
@@ -11,17 +11,16 @@ export default {
const Template = ({
emoji,
count,
checked,
showCount,
...rest
}: {
emoji: string;
count: number;
checked?: boolean;
showCount: boolean;
userIsAuthor: boolean;
disabled?: boolean;
userReacted: boolean;
}) => {
return (
<MessageEmojiButton emoji={{ name: emoji, count }} checked={checked} onClick={undefined} showCount={showCount} />
);
return <MessageEmojiButton emoji={{ name: emoji, count }} onClick={undefined} {...rest} />;
};
export const Default = Template.bind({});
@@ -29,7 +28,9 @@ Default.args = {
emoji: "+1",
count: 7,
checked: false,
showCount: true,
userIsAuthor: false,
disabled: false,
userReacted: true,
};
export const BigNumber = Template.bind({});
@@ -1,4 +1,5 @@
import { Button } from "@chakra-ui/react";
import { useHasRole } from "src/hooks/auth/useHasRole";
import { MessageEmoji } from "src/types/Conversation";
import { emojiIcons } from "src/types/Emoji";
@@ -6,12 +7,27 @@ interface MessageEmojiButtonProps {
emoji: MessageEmoji;
checked?: boolean;
onClick: () => void;
showCount: boolean;
userIsAuthor: boolean;
disabled?: boolean;
userReacted: boolean;
}
export const MessageEmojiButton = ({ emoji, checked, onClick, showCount }: MessageEmojiButtonProps) => {
export const MessageEmojiButton = ({
emoji,
checked,
onClick,
userIsAuthor,
disabled,
userReacted,
}: MessageEmojiButtonProps) => {
const EmojiIcon = emojiIcons.get(emoji.name);
if (!EmojiIcon) return <></>;
const isAdmin = useHasRole("admin");
if (!EmojiIcon) return null;
const isDisabled = !!(userIsAuthor ? true : disabled);
const showCount = (emoji.count > 0 && userReacted) || userIsAuthor || isAdmin;
return (
<Button
onClick={onClick}
@@ -21,9 +37,16 @@ export const MessageEmojiButton = ({ emoji, checked, onClick, showCount }: Messa
height="1.6em"
minWidth={0}
padding="0"
disabled={disabled}
sx={{
":hover": {
backgroundColor: isDisabled ? "transparent" : undefined,
},
}}
color={isDisabled ? "gray.500" : undefined}
>
<EmojiIcon style={{ height: "1em" }} />
{emoji.count > 0 && showCount && <span style={{ marginInlineEnd: "0.25em" }}>{emoji.count}</span>}
{showCount && <span style={{ marginInlineEnd: "0.25em" }}>{emoji.count}</span>}
</Button>
);
};
@@ -116,7 +116,8 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE
key={emoji}
emoji={{ name: emoji, count }}
checked={emojiState.user_emojis.includes(emoji)}
showCount={emojiState.user_emojis.filter((emoji) => emoji === "+1" || emoji === "-1").length > 0}
userReacted={emojiState.user_emojis.length > 0}
userIsAuthor={message.user_is_author}
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
/>
);
+1 -1
View File
@@ -7,7 +7,7 @@ import { get } from "src/lib/api";
import type { FetchUsersResponse, User } from "src/types/Users";
import useSWR from "swr";
import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable";
import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable/DataTable";
interface Pagination {
/**
+2 -2
View File
@@ -314,8 +314,8 @@ export class OasstApiClient {
return this.get<Message[]>(`/api/v1/messages?${params}`);
}
fetch_recent_messages() {
return this.get<Message[]>(`/api/v1/messages`);
fetch_recent_messages(lang: string) {
return this.get<Message[]>(`/api/v1/messages`, { lang });
}
fetch_message_children(messageId: string) {
+1 -1
View File
@@ -18,7 +18,7 @@ const Leaderboard = () => {
<AdminArea>
<Box display="flex" flexDirection="column">
<Heading fontSize="2xl" fontWeight="bold" pb="4">
{t("leaderboard")}
Trollboard
</Heading>
<Card>
<CardBody>
+3 -1
View File
@@ -1,9 +1,11 @@
import { withoutRole } from "src/lib/auth";
import { createApiClient } from "src/lib/oasst_client_factory";
import { getUserLanguage } from "src/lib/users";
const handler = withoutRole("banned", async (req, res, token) => {
const client = await createApiClient(token);
const messages = await client.fetch_recent_messages();
const userLanguage = getUserLanguage(req);
const messages = await client.fetch_recent_messages(userLanguage);
res.status(200).json(messages);
});
+6 -1
View File
@@ -1,6 +1,7 @@
import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
import { useTranslation } from "next-i18next";
import { useCookies } from "react-cookie";
import { getDashboardLayout } from "src/components/Layout";
import { MessageTable } from "src/components/Messages/MessageTable";
import { get } from "src/lib/api";
@@ -15,6 +16,8 @@ const MessagesDashboard = () => {
const { data: messages } = useSWRImmutable("/api/messages", get, { revalidateOnMount: true });
const { data: userMessages } = useSWRImmutable(`/api/messages/user`, get, { revalidateOnMount: true });
const [cookies] = useCookies(["NEXT_LOCALE"]);
const currentLanguage = cookies["NEXT_LOCALE"] || "en";
return (
<>
<Head>
@@ -24,7 +27,9 @@ const MessagesDashboard = () => {
<SimpleGrid columns={[1, 1, 1, 1, 1, 2]} gap={4}>
<Box>
<Text className="text-2xl font-bold" pb="4">
{t("recent_messages")}
{t("recent_messages", {
language: new Intl.DisplayNames([currentLanguage], { type: "language" }).of(currentLanguage),
})}
</Text>
<Box
backgroundColor={boxBgColor}
+8
View File
@@ -19,6 +19,14 @@ export interface Message extends MessageEmojis {
parent_id: string;
frontend_message_id?: string;
user_id: string;
user_is_author: boolean | null;
deleted: boolean | null;
synthetic: boolean | null;
message_tree_id: string;
ranking_count: number | null;
rank: number | null;
model_name: string | null;
review_count: number | null;
}
export interface Conversation {