From e2caf5365465254e8f5563f1dcbeeed9e32710af Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Wed, 8 Feb 2023 08:01:04 +0000 Subject: [PATCH 01/10] First version of single GPU sampling working --- .../supervised_finetuning/configs/config.yaml | 49 +++++++-------- model/supervised_finetuning/requirements.txt | 2 +- model/supervised_finetuning/trainer.py | 36 ++++++++--- model/supervised_finetuning/utils.py | 59 ++++++++++++++++++- 4 files changed, 112 insertions(+), 34 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 1d196fb2..925b6dda 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -16,29 +16,32 @@ defaults: eval_accumulation_steps: freeze_layer: datasets: - - webgpt - - prompt_dialogue - - squad_v2 - - adversarial_qa - - trivia_qa_nocontext - - xsum - - cnn_dailymail - - prompt_dialogue - - multi_news - - scitldr - - soda - - joke - - gsm8k - - dive_mt - - wmt2019_zh-en - - wmt2019_ru-en - - wmt2019_de-en - - ted_trans_nl-en - - ted_trans_de-ja - - instruct_tuning - - wmt2019_de-en - - samsum - - soda_dialogue + - webgpt: + fraction : 0.001 + - prompt_dialogue: + size : 10 + - squad_v2: + size : 10 + # - adversarial_qa + # - trivia_qa_nocontext + # - xsum + # - cnn_dailymail + # - prompt_dialogue + # - multi_news + # - scitldr + # - soda + # - joke + # - gsm8k + # - dive_mt + # - wmt2019_zh-en + # - wmt2019_ru-en + # - wmt2019_de-en + # - ted_trans_nl-en + # - ted_trans_de-ja + # - instruct_tuning + # - wmt2019_de-en + # - samsum + # - soda_dialogue cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index 8f8cc63c..efe6df89 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,7 +4,7 @@ datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 gdown -mpi4py==3.1.4 +# mpi4py==3.1.4 nltk==3.8.1 numpy>=1.22.4 py7zr diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 0acb10dd..06ae39db 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -8,7 +8,8 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls +from utils import (build_train_sampler, get_dataset, get_loss, get_metrics, + get_model, get_tokenizer, read_yamls) def compute_metrics(eval_pred, preprocess_fns, metrics): @@ -30,6 +31,7 @@ class SFTTrainer(Trainer): self, model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, + sampler: torch.utils.data.sampler.Sampler = None, loss_function: str = "CrossEntropyLoss", poly_eps: float = 1.0, **kwargs, @@ -38,6 +40,7 @@ class SFTTrainer(Trainer): # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct self.loss_fct = get_loss(loss_function, poly_eps) + self.sampler = sampler def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") @@ -88,6 +91,22 @@ class SFTTrainer(Trainer): return (loss, logits, labels) + def get_train_dataloader(self): + if self.sampler is None: + torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=self.data_collator + ) + else: + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + sampler=self.sampler, + collate_fn=self.data_collator + ) + def _strtobool(x): return bool(strtobool(x)) @@ -141,8 +160,8 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) + sampler = build_train_sampler(training_conf, train.datasets) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF if training_conf.quantization: @@ -159,7 +178,7 @@ if __name__ == "__main__": learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, optim=optimizer, - fp16=True, + # fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, gradient_accumulation_steps=training_conf.gradient_accumulation_steps, @@ -177,19 +196,20 @@ if __name__ == "__main__": ) assert len(evals) > 0 - if not training_conf.deepspeed or training_conf.local_rank == 0: import wandb wandb.init( project="supervised-finetuning", - entity=training_conf.wandb_entity, + # entity=training_conf.wandb_entity, + entity="maw501", name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) - + trainer = SFTTrainer( - model, - args, + model=model, + args=args, + sampler=sampler, loss_function=training_conf.loss_fn, poly_eps=training_conf.poly_eps, train_dataset=train, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f7a0ab15..6d6b271a 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -2,9 +2,10 @@ from pathlib import Path import evaluate +import random # import nltk -# import numpy as np +import numpy as np import transformers import yaml from custom_datasets import get_one_dataset @@ -14,6 +15,35 @@ from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset +from torch.utils.data.sampler import Sampler + + +class ClassSampler(Sampler): + """Sampler which returns a fixed number of samples per class, per epoch""" + def __init__(self, class_labels, class_sizes): + self.class_labels = class_labels + self.class_sizes = class_sizes + + def __iter__(self): + out = [] + for i, _class in enumerate(np.unique(self.class_labels)): + class_idx = np.argwhere(self.class_labels == _class).flatten() + sampled_idx = random.sample(list(class_idx), int(self.class_sizes[i])) + out.extend(sampled_idx) + random.shuffle(out) + return iter(out) + + def __len__(self): + return int(sum(self.class_sizes)) + + +def build_train_sampler(training_conf, datasets): + train_sizes = [len(x) for x in datasets] + fractions = get_dataset_fractions(training_conf.datasets, train_sizes) + dataset_size_per_epoch = [int(size * frac) for size, frac in zip(train_sizes, fractions)] + dataset_labels = [[i] * d for i, d in zip(range(len(dataset_size_per_epoch)), dataset_size_per_epoch)] + dataset_labels = [i for s in dataset_labels for i in s] + return ClassSampler(dataset_labels, dataset_size_per_epoch) def get_tokenizer(conf): @@ -115,10 +145,35 @@ def get_model(conf, tokenizer): return model +def get_dataset_name_from_data_config(data_config): + if isinstance(data_config, dict): + return list(data_config.keys())[0] + return data_config + + +def get_dataset_fractions(conf, dataset_sizes): + fractions = [] + for i, data_config in enumerate(conf): + dataset_name = get_dataset_name_from_data_config(data_config) + if isinstance(data_config, dict): + if "fraction" in data_config[dataset_name]: + fractions.append(min(1, data_config[dataset_name]["fraction"])) + elif "size" in data_config[dataset_name]: + if data_config[dataset_name]["size"] > dataset_sizes[i]: + raise ValueError(f"Please specify a size smaller than number of examples ({dataset_sizes[i]})") + fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) + else: + raise ValueError("Please specify either fraction or size in config.yaml") + else: + fractions.append(1) + return fractions + + def get_dataset(conf, tokenizer): train_datasets, evals = [], {} - for dataset_name in conf.datasets: + for data_config in conf.datasets: + dataset_name = get_dataset_name_from_data_config(data_config) train, val = get_one_dataset(conf, dataset_name) train_datasets.append(train) evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val From 283df8ec84bae350c8c13e364807443c36ad071e Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Wed, 8 Feb 2023 20:49:25 +0000 Subject: [PATCH 02/10] Get working on multi-gpu --- model/supervised_finetuning/README.md | 128 +++++++++++------- .../supervised_finetuning/configs/config.yaml | 47 +++---- .../custom_datasets/qa_datasets.py | 5 +- .../custom_datasets/translation.py | 2 - model/supervised_finetuning/requirements.txt | 2 +- model/supervised_finetuning/trainer.py | 25 ++-- model/supervised_finetuning/utils.py | 78 +++++++---- 7 files changed, 167 insertions(+), 120 deletions(-) diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index d5b10e01..387e91e4 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -1,62 +1,18 @@ # Train using supervised examples -Requirements +## Requirements -``` -wandb -evaluate -datasets -transformers -torch -``` +`pip install -r requirements.txt` -Start training reward model +Start training SFT model ```bash -python trainer.py --configs defaults galactica-125 +python trainer.py --configs defaults galactica-125m ``` -## Dataset - -For now we only support webgpt and summary dataset from OpenAI. Once -open-asisstant dataset are available it will be added here. - -## Model - -Normally you should be able to add new models in configs/config.yml - -``` -your-model-name: - learning_rate: 2e-6 - model_name: - weight_decay: 0.01 - max_length: 812 - warmup_steps: 600 - gradient_checkpointing: false - gradient_accumulation_steps: 5 - per_device_train_batch_size: 4 - per_device_eval_batch_size: 4 -``` - -``` -python trainer.py --configs defaults your-model-name -``` - -However, if the model of your choice doesn't have pad_token, eos_token, -sep_token, you have to update utils.py `get_tokenizer` to use the right token. - -## Deepspeed support - -You can edit the configs/zero_config.json and use any stage you wish. The -current config uses zero-stage 3. For more details on how to setup the config -checkout [this page](https://www.deepspeed.ai/tutorials/zero/) - -Once you are satisfy with your deepzero config, you can add --deepspeed flag at -the end to trigger deepspeed - -``` -python trainer.py --configs defaults your-model-name --deepspeed -``` +For `wandb`: update the `entity` argument in `trainer.py`'s call to `wandb.init` +to be your weights and biases username per +[docs](https://docs.wandb.ai/ref/python/init). ## Dataset choices @@ -80,6 +36,74 @@ Currently only these languages are supported via prompt translation: ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh ``` +## Dataset sub-sampling + +We can subsample the **training** data by passing either the `fraction` or +`size` argument in the `configs/config.yml` file. Don't forget the additional +colon ":" after the dataset name when doing this. + +Example: + +``` + datasets: + - webgpt: + fraction : 0.05 + - prompt_dialogue: + size : 500 + - adversarial_qa + - trivia_qa_nocontext +``` + +In this example, per epoch we will use: + +- A random 5% of `webgpt`; +- A random 500 examples from `prompt_dialogue`; +- All examples from datasets for which we don't specify the `fraction` or `size` + argument. + +In the above example, per epoch we'll use a different 5% from `webgpt` and a +different 500 examples from `prompt_dialogue`. + +This works with `torch.distributed`. + +## Model + +Normally you should be able to add new models in `configs/config.yml` + +``` +your-model-name: + learning_rate: 2e-6 + model_name: + weight_decay: 0.01 + max_length: 812 + warmup_steps: 600 + gradient_checkpointing: false + gradient_accumulation_steps: 5 + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 +``` + +``` +python trainer.py --configs defaults your-model-name +``` + +However, if the model of your choice doesn't have `pad_token`, `eos_token`, +`sep_token`, you have to update `get_tokenizer` in `utils.py` to use the right +token. + +## Deepspeed support + +You can edit the configs/zero_config.json and use any stage you wish. The +current config uses zero-stage 3. For more details on how to setup the config +checkout [this page](https://www.deepspeed.ai/tutorials/zero/). + +Once you are satisfy with your deepzero config, you can add --deepspeed flag at +the end to trigger deepspeed + +``` +python trainer.py --configs defaults your-model-name --deepspeed +``` + ## Results Experimental results in wandb @@ -87,7 +111,7 @@ Experimental results in wandb ## TODOS -- decide on a model +- Decide on a model - Merge utils etc with reward model - Casual Modelling for GPT-JT does not leverage the bidirectional mask for the prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 925b6dda..289e49ad 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -16,32 +16,29 @@ defaults: eval_accumulation_steps: freeze_layer: datasets: - - webgpt: - fraction : 0.001 - - prompt_dialogue: - size : 10 - - squad_v2: - size : 10 - # - adversarial_qa - # - trivia_qa_nocontext - # - xsum - # - cnn_dailymail + - webgpt # - prompt_dialogue - # - multi_news - # - scitldr - # - soda - # - joke - # - gsm8k - # - dive_mt - # - wmt2019_zh-en - # - wmt2019_ru-en - # - wmt2019_de-en - # - ted_trans_nl-en - # - ted_trans_de-ja - # - instruct_tuning - # - wmt2019_de-en - # - samsum - # - soda_dialogue + - squad_v2 + - adversarial_qa + - trivia_qa_nocontext + - xsum + - cnn_dailymail + - prompt_dialogue + - multi_news + - scitldr + - soda + - joke + - gsm8k + - dive_mt + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja + - instruct_tuning + - wmt2019_de-en + - samsum + - soda_dialogue cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 2c5c7ee2..206b679c 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -219,7 +219,7 @@ class SODA(Dataset): return pairs - def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: + def __init__(self, cache_dir, input_max_length=1024) -> None: super().__init__() self.pairs = [] @@ -230,9 +230,6 @@ class SODA(Dataset): if len(prompt) < input_max_length: self.pairs.append((prompt, answer)) - if len(self.pairs) > max_sample_size: - break - def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index f9a71a8e..008de751 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -100,8 +100,6 @@ class WMT2019(TranslationPair): else: # translating in reverse direction source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) - if len(self.pairs) > 100000: - break class DiveMT(TranslationPair): diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index efe6df89..95e5a472 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,7 +4,6 @@ datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 gdown -# mpi4py==3.1.4 nltk==3.8.1 numpy>=1.22.4 py7zr @@ -12,3 +11,4 @@ PyYAML>=6.0 scikit_learn==1.2.0 torch>=1.11.0 transformers==4.25.1 +wandb diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 06ae39db..922be10b 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -8,8 +8,7 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import (build_train_sampler, get_dataset, get_loss, get_metrics, - get_model, get_tokenizer, read_yamls) +from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls def compute_metrics(eval_pred, preprocess_fns, metrics): @@ -92,20 +91,30 @@ class SFTTrainer(Trainer): return (loss, logits, labels) def get_train_dataloader(self): + """Inject custom data sampling behaviour into training loop""" if self.sampler is None: torch.utils.data.DataLoader( self.train_dataset, batch_size=self.args.per_device_train_batch_size, shuffle=True, - collate_fn=self.data_collator + collate_fn=self.data_collator, ) else: - return torch.utils.data.DataLoader( + dataloader = torch.utils.data.DataLoader( self.train_dataset, batch_size=self.args.per_device_train_batch_size, sampler=self.sampler, - collate_fn=self.data_collator + collate_fn=self.data_collator, ) + if torch.cuda.device_count() <= 1: + return dataloader + else: + # Not strictly necessary to use accelerate, currently just + # ensures batches are padded to be divisible by # devices + from accelerate import Accelerator + + accelerator = Accelerator() + return accelerator.prepare(dataloader) def _strtobool(x): @@ -160,7 +169,7 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) - sampler = build_train_sampler(training_conf, train.datasets) + sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets) metrics, preprocess_fns = get_metrics(training_conf, tokenizer) optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF @@ -178,7 +187,7 @@ if __name__ == "__main__": learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, optim=optimizer, - # fp16=True, + fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, gradient_accumulation_steps=training_conf.gradient_accumulation_steps, @@ -205,7 +214,7 @@ if __name__ == "__main__": entity="maw501", name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) - + trainer = SFTTrainer( model=model, args=args, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 6d6b271a..c18cb380 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,11 +1,8 @@ -# from functools import partial +import random from pathlib import Path +from typing import List import evaluate -import random - -# import nltk -import numpy as np import transformers import yaml from custom_datasets import get_one_dataset @@ -18,32 +15,55 @@ from torch.utils.data import ConcatDataset, Subset from torch.utils.data.sampler import Sampler -class ClassSampler(Sampler): - """Sampler which returns a fixed number of samples per class, per epoch""" - def __init__(self, class_labels, class_sizes): - self.class_labels = class_labels - self.class_sizes = class_sizes +class PerDatasetSampler(Sampler): + """Sampler which returns a fixed number of samples per dataset, per epoch. + + Example: + + Dataset 1 has 10,000 examples and we want 200 per epoch + Dataset 2 has 500 examples and we want all 500 per epoch + + Epoch size will be 700 and every epoch we'll sample a different + 200 from dataset 1. + + Parameters + ---------- + dataset_sizes : List[int] + A list with the size of each dataset. + dataset_size_per_epoch : List[int] + How many examples to get from each dataset per epoch. + + Note: dataset_sizes & dataset_size_per_epoch must be in the same order. + Further the examples in the underlying torch.utils.data.Dataset + must per ordered as dataset_1, dataset_2, ..., dataset_n. This is fine + if we concatenate a bunch of datasets together + e.g. using torch.utils.data.ConcatDataset which is current behaviour. + """ + + def __init__(self, dataset_sizes: List[int], dataset_size_per_epoch: List[int]): + self.dataset_sizes = dataset_sizes + self.dataset_size_per_epoch = dataset_size_per_epoch + self.num_datasets = len(dataset_sizes) def __iter__(self): - out = [] - for i, _class in enumerate(np.unique(self.class_labels)): - class_idx = np.argwhere(self.class_labels == _class).flatten() - sampled_idx = random.sample(list(class_idx), int(self.class_sizes[i])) - out.extend(sampled_idx) - random.shuffle(out) - return iter(out) + epoch_idx = [] + n = 0 + for i in range(self.num_datasets): + sampled_idx = random.sample(range(n, self.dataset_sizes[i] + n), self.dataset_size_per_epoch[i]) + n += self.dataset_sizes[i] + epoch_idx.extend(sampled_idx) + random.shuffle(epoch_idx) + return iter(epoch_idx) def __len__(self): - return int(sum(self.class_sizes)) + return int(sum(self.dataset_size_per_epoch)) - -def build_train_sampler(training_conf, datasets): - train_sizes = [len(x) for x in datasets] - fractions = get_dataset_fractions(training_conf.datasets, train_sizes) - dataset_size_per_epoch = [int(size * frac) for size, frac in zip(train_sizes, fractions)] - dataset_labels = [[i] * d for i, d in zip(range(len(dataset_size_per_epoch)), dataset_size_per_epoch)] - dataset_labels = [i for s in dataset_labels for i in s] - return ClassSampler(dataset_labels, dataset_size_per_epoch) + @classmethod + def build_sampler_from_config(cls, training_conf, datasets): + dataset_sizes = [len(x) for x in datasets] + fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes) + dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)] + return cls(dataset_sizes, dataset_size_per_epoch) def get_tokenizer(conf): @@ -157,13 +177,15 @@ def get_dataset_fractions(conf, dataset_sizes): dataset_name = get_dataset_name_from_data_config(data_config) if isinstance(data_config, dict): if "fraction" in data_config[dataset_name]: + if data_config[dataset_name]["fraction"] <= 0: + raise ValueError("Please specify fraction as a value between 0 < fraction <= 1") fractions.append(min(1, data_config[dataset_name]["fraction"])) elif "size" in data_config[dataset_name]: if data_config[dataset_name]["size"] > dataset_sizes[i]: - raise ValueError(f"Please specify a size smaller than number of examples ({dataset_sizes[i]})") + raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}") fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) else: - raise ValueError("Please specify either fraction or size in config.yaml") + raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.") else: fractions.append(1) return fractions From 9faae250ce176c1025cf529003ceb3a5ef188a92 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Wed, 8 Feb 2023 20:57:13 +0000 Subject: [PATCH 03/10] minor tidy-up --- model/supervised_finetuning/trainer.py | 3 +-- model/supervised_finetuning/utils.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 922be10b..72ae9f04 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -210,8 +210,7 @@ if __name__ == "__main__": wandb.init( project="supervised-finetuning", - # entity=training_conf.wandb_entity, - entity="maw501", + entity=training_conf.wandb_entity, name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", ) diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index c18cb380..6c47bdb7 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -172,6 +172,7 @@ def get_dataset_name_from_data_config(data_config): def get_dataset_fractions(conf, dataset_sizes): + """Calculate fraction of each dataset to use per epoch when subsampling""" fractions = [] for i, data_config in enumerate(conf): dataset_name = get_dataset_name_from_data_config(data_config) From ef548edb728c941b92f66432c0c0b8a1a44cd6c4 Mon Sep 17 00:00:00 2001 From: notmd <33456881+notmd@users.noreply.github.com> Date: Thu, 9 Feb 2023 15:20:27 +0700 Subject: [PATCH 04/10] Trollboard expandable (#1354) * wip * hide necessary column in trollboard * remove console.log * fix build * clean up * remove commented code --- .../components/{ => DataTable}/DataTable.tsx | 26 +++++- .../DataTable/jsonExpandRowModel.tsx | 60 +++++++++++++ .../LeaderboardTable/LeaderboardTable.tsx | 33 ++++--- .../LeaderboardTable/TrollboardTable.tsx | 88 ++++++++++--------- .../LeaderboardTable/useBoardRowProps.ts | 2 +- website/src/components/UserTable.tsx | 2 +- website/src/pages/admin/trollboard.tsx | 2 +- 7 files changed, 153 insertions(+), 60 deletions(-) rename website/src/components/{ => DataTable}/DataTable.tsx (90%) create mode 100644 website/src/components/DataTable/jsonExpandRowModel.tsx diff --git a/website/src/components/DataTable.tsx b/website/src/components/DataTable/DataTable.tsx similarity index 90% rename from website/src/components/DataTable.tsx rename to website/src/components/DataTable/DataTable.tsx index 35246bf7..3e75f32e 100644 --- a/website/src/components/DataTable.tsx +++ b/website/src/components/DataTable/DataTable.tsx @@ -23,13 +23,23 @@ import { Tr, useDisclosure, } from "@chakra-ui/react"; -import { Cell, ColumnDef, flexRender, getCoreRowModel, Row, useReactTable } from "@tanstack/react-table"; +import { + Cell, + ColumnDef, + ExpandedState, + flexRender, + getCoreRowModel, + getExpandedRowModel, + Row, + useReactTable, +} from "@tanstack/react-table"; import { Filter } from "lucide-react"; import { useTranslation } from "next-i18next"; -import { ChangeEvent, ReactNode } from "react"; +import { ChangeEvent, ReactNode, useState } from "react"; import { useDebouncedCallback } from "use-debounce"; -export type DataTableColumnDef = ColumnDef & { +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type DataTableColumnDef = ColumnDef & { filterable?: boolean; span?: number | ((cell: Cell) => number | undefined); }; @@ -54,6 +64,7 @@ export type DataTableProps = { disablePrevious?: boolean; disablePagination?: boolean; rowProps?: TableRowProps | DataTableRowPropsCallback; + getSubRows?: (row: T) => T[] | undefined; }; export const DataTable = ({ @@ -68,12 +79,21 @@ export const DataTable = ({ disablePrevious, disablePagination, rowProps, + getSubRows, }: DataTableProps) => { const { t } = useTranslation("leaderboard"); + const [expanded, setExpanded] = useState({}); + const { getHeaderGroups, getRowModel } = useReactTable({ data, columns, getCoreRowModel: getCoreRowModel(), + getExpandedRowModel: getExpandedRowModel(), + state: { + expanded, + }, + getSubRows, + onExpandedChange: setExpanded, }); const handleFilterChange = (value: FilterItem) => { diff --git a/website/src/components/DataTable/jsonExpandRowModel.tsx b/website/src/components/DataTable/jsonExpandRowModel.tsx new file mode 100644 index 00000000..84c70e20 --- /dev/null +++ b/website/src/components/DataTable/jsonExpandRowModel.tsx @@ -0,0 +1,60 @@ +import { Card, CardBody, Flex } from "@chakra-ui/react"; +import { Cell, CellContext } from "@tanstack/react-table"; +import { ChevronDown, ChevronRight } from "lucide-react"; + +type ExpandableRow = Omit & { + shouldExpand?: boolean; +}; + +export const createJsonExpandRowModel = () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const renderCell = ({ row, getValue }: CellContext, any>) => { + if (!row.original.shouldExpand) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { shouldExpand, ...res } = row.original; + return ( + + +
{JSON.stringify(res, null, 2)}
+
+
+ ); + } + + return ( + + {row.getCanExpand() ? ( + + ) : null}{" "} + {getValue()} + + ); + }; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const span = (cell: Cell, any>) => + cell.row.original.shouldExpand ? undefined : cell.row.getVisibleCells().length; + + const getSubRows = (row: ExpandableRow) => + row.shouldExpand + ? [ + { + ...row, + shouldExpand: false, + }, + ] + : undefined; + + const toExpandable = function (arr: T[] | undefined, val = true): ExpandableRow[] { + return !arr ? [] : arr.map((element) => ({ ...element, shouldExpand: val })); + }; + + return { renderCell, span, getSubRows, toExpandable }; +}; diff --git a/website/src/components/LeaderboardTable/LeaderboardTable.tsx b/website/src/components/LeaderboardTable/LeaderboardTable.tsx index d59fc902..fc3ae099 100644 --- a/website/src/components/LeaderboardTable/LeaderboardTable.tsx +++ b/website/src/components/LeaderboardTable/LeaderboardTable.tsx @@ -2,19 +2,20 @@ import { Box, CircularProgress, Flex, Link, useColorModeValue } from "@chakra-ui import { createColumnHelper } from "@tanstack/react-table"; import { MoreHorizontal } from "lucide-react"; import NextLink from "next/link"; -import { useSession } from "next-auth/react"; import { useTranslation } from "next-i18next"; import React, { useMemo } from "react"; +import { useHasRole } from "src/hooks/auth/useHasRole"; import { LeaderboardEntity, LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard"; -import { DataTable, DataTableColumnDef } from "../DataTable"; +import { DataTable, DataTableColumnDef } from "../DataTable/DataTable"; +import { createJsonExpandRowModel } from "../DataTable/jsonExpandRowModel"; import { useBoardPagination } from "./useBoardPagination"; import { useBoardRowProps } from "./useBoardRowProps"; import { useFetchBoard } from "./useFetchBoard"; type WindowLeaderboardEntity = LeaderboardEntity & { isSpaceRow?: boolean }; const columnHelper = createColumnHelper(); - +const jsonExpandRowModel = createJsonExpandRowModel(); /** * Presents a grid of leaderboard entries with more detailed information. */ @@ -39,17 +40,24 @@ export const LeaderboardTable = ({ } = useFetchBoard( `/api/leaderboard?time_frame=${timeFrame}&limit=${limit}&includeUserStats=${!hideCurrentUserRanking}` ); - const { data: session } = useSession(); - const isAdmin = session?.user?.role === "admin"; + const isAdmin = useHasRole("admin"); + const columns: DataTableColumnDef[] = useMemo( () => [ { ...columnHelper.accessor("rank", { header: t("rank"), - cell: ({ row, getValue }) => (row.original.isSpaceRow ? : getValue()), + cell: (ctx) => + ctx.row.original.isSpaceRow ? ( + + ) : isAdmin ? ( + jsonExpandRowModel.renderCell(ctx) + ) : ( + ctx.getValue() + ), }), - span: (cell) => (cell.row.original.isSpaceRow ? 6 : undefined), + span: (cell) => (cell.row.original.isSpaceRow ? 6 : jsonExpandRowModel.span(cell)), }, columnHelper.accessor("display_name", { header: t("user"), @@ -82,17 +90,17 @@ export const LeaderboardTable = ({ data: paginatedData, end, ...pagnationProps - } = useBoardPagination({ rowPerPage, data: reply?.leaderboard, limit }); - const data: WindowLeaderboardEntity[] = useMemo(() => { - if (hideCurrentUserRanking || !reply?.user_stats_window) { + } = useBoardPagination({ rowPerPage, data: jsonExpandRowModel.toExpandable(reply?.leaderboard || []), limit }); + const data = useMemo(() => { + if (hideCurrentUserRanking || !reply?.user_stats_window || reply.user_stats_window.length === 0) { return paginatedData; } - const userStatsWindow: WindowLeaderboardEntity[] = reply.user_stats_window; + const userStatsWindow: WindowLeaderboardEntity[] = jsonExpandRowModel.toExpandable(reply.user_stats_window); const userStats = userStatsWindow.find((stats) => stats.highlighted); if (userStats && userStats.rank > end) { paginatedData.push( { isSpaceRow: true } as WindowLeaderboardEntity, - ...reply.user_stats_window.filter( + ...userStatsWindow.filter( (stats) => paginatedData.findIndex((leaderBoardEntity) => leaderBoardEntity.user_id === stats.user_id) === -1 ) // filter to avoid duplicated row ); @@ -116,6 +124,7 @@ export const LeaderboardTable = ({ columns={columns} caption={lastUpdated} rowProps={rowProps} + getSubRows={jsonExpandRowModel.getSubRows} {...pagnationProps} > ); diff --git a/website/src/components/LeaderboardTable/TrollboardTable.tsx b/website/src/components/LeaderboardTable/TrollboardTable.tsx index 1b1ea118..0c971655 100644 --- a/website/src/components/LeaderboardTable/TrollboardTable.tsx +++ b/website/src/components/LeaderboardTable/TrollboardTable.tsx @@ -1,26 +1,42 @@ -import { Box, CircularProgress, Flex, Link } from "@chakra-ui/react"; +import { Box, CircularProgress, Flex, IconButton, Link, Tooltip } from "@chakra-ui/react"; import { createColumnHelper } from "@tanstack/react-table"; -import { ThumbsDown, ThumbsUp } from "lucide-react"; +import { Mail, ThumbsDown, ThumbsUp, User } from "lucide-react"; import NextLink from "next/link"; import { FetchTrollBoardResponse, TrollboardEntity, TrollboardTimeFrame } from "src/types/Trollboard"; -import { DataTable } from "../DataTable"; +import { DataTable, DataTableColumnDef } from "../DataTable/DataTable"; +import { createJsonExpandRowModel } from "../DataTable/jsonExpandRowModel"; +import { Discord } from "../Icons/Discord"; import { useBoardPagination } from "./useBoardPagination"; import { useBoardRowProps } from "./useBoardRowProps"; import { useFetchBoard } from "./useFetchBoard"; + const columnHelper = createColumnHelper(); - const toPercentage = (num: number) => `${Math.round(num * 100)}%`; +const jsonExpandRowModel = createJsonExpandRowModel(); -const columns = [ - columnHelper.accessor("rank", {}), +const columns: DataTableColumnDef[] = [ + { + ...columnHelper.accessor("rank", { + cell: jsonExpandRowModel.renderCell, + }), + span: jsonExpandRowModel.span, + }, columnHelper.accessor("display_name", { header: "Display name", - cell: ({ getValue, row }) => ( - - {getValue()} - - ), + cell: ({ getValue, row }) => { + const isEmail = row.original.auth_method === "local"; + return ( + + + {getValue()} + + + {isEmail ? : } + + + ); + }, }), columnHelper.accessor("troll_score", { header: "Troll score", @@ -45,36 +61,19 @@ const columns = [ columnHelper.accessor((row) => row.spam + row.spam_prompts, { header: "Spam", }), - columnHelper.accessor("lang_mismach", { - header: "Lang mismach", - }), - columnHelper.accessor("not_appropriate", { - header: "Not appropriate", - }), - columnHelper.accessor("pii", {}), - columnHelper.accessor("hate_speech", { - header: "Hate speech", - }), - columnHelper.accessor("sexual_content", { - header: "Sexual Content", - }), - columnHelper.accessor("political_content", { - header: "Political Content", - }), - columnHelper.accessor("quality", { - cell: ({ getValue }) => toPercentage(getValue()), - }), - columnHelper.accessor("helpfulness", { - cell: ({ getValue }) => toPercentage(getValue()), - }), - columnHelper.accessor("humor", { - cell: ({ getValue }) => toPercentage(getValue()), - }), - columnHelper.accessor("violence", { - cell: ({ getValue }) => toPercentage(getValue()), - }), columnHelper.accessor("toxicity", { - cell: ({ getValue }) => toPercentage(getValue()), + cell: ({ getValue }) => toPercentage(getValue() || 0), + }), + columnHelper.accessor((row) => row.user_id, { + header: "Actions", + cell: ({ row }) => ( + } + > + ), }), ]; @@ -94,7 +93,11 @@ export const TrollboardTable = ({ lastUpdated, } = useFetchBoard(`/api/admin/trollboard?time_frame=${timeFrame}&limit=${limit}`); - const { data, ...paginationProps } = useBoardPagination({ rowPerPage, data: trollboardRes?.trollboard, limit }); + const { data, ...paginationProps } = useBoardPagination({ + rowPerPage, + data: jsonExpandRowModel.toExpandable(trollboardRes?.trollboard), + limit, + }); const rowProps = useBoardRowProps(); if (isLoading) { return ; @@ -112,11 +115,12 @@ export const TrollboardTable = ({ }, }} > - + diff --git a/website/src/components/LeaderboardTable/useBoardRowProps.ts b/website/src/components/LeaderboardTable/useBoardRowProps.ts index 32f3fc56..be0fe7de 100644 --- a/website/src/components/LeaderboardTable/useBoardRowProps.ts +++ b/website/src/components/LeaderboardTable/useBoardRowProps.ts @@ -2,7 +2,7 @@ import { useColorModeValue, useToken } from "@chakra-ui/react"; import { useCallback } from "react"; import { colors } from "src/styles/Theme/colors"; -import { DataTableRowPropsCallback } from "../DataTable"; +import { DataTableRowPropsCallback } from "../DataTable/DataTable"; export const useBoardRowProps = () => { const borderColor = useToken("colors", useColorModeValue(colors.light.active, colors.dark.active)); diff --git a/website/src/components/UserTable.tsx b/website/src/components/UserTable.tsx index ab05d065..5c51c585 100644 --- a/website/src/components/UserTable.tsx +++ b/website/src/components/UserTable.tsx @@ -7,7 +7,7 @@ import { get } from "src/lib/api"; import type { FetchUsersResponse, User } from "src/types/Users"; import useSWR from "swr"; -import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable"; +import { DataTable, DataTableColumnDef, FilterItem } from "./DataTable/DataTable"; interface Pagination { /** diff --git a/website/src/pages/admin/trollboard.tsx b/website/src/pages/admin/trollboard.tsx index aaef80f3..47e69d6b 100644 --- a/website/src/pages/admin/trollboard.tsx +++ b/website/src/pages/admin/trollboard.tsx @@ -18,7 +18,7 @@ const Leaderboard = () => { - {t("leaderboard")} + Trollboard From 7b16ee9a7568e5ae92f8edf4fd4f8e74fda8d0c3 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Thu, 9 Feb 2023 19:29:11 +0900 Subject: [PATCH 05/10] When loading recent messages, filter by the user's stored language. (#1292) A simple short term fix. Applies to #962. To fully fix a small backend change is needed. --- website/public/locales/en/message.json | 2 +- website/src/lib/oasst_api_client.ts | 4 ++-- website/src/pages/api/messages/index.ts | 4 +++- website/src/pages/messages/index.tsx | 7 ++++++- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/website/public/locales/en/message.json b/website/public/locales/en/message.json index a531f0a4..75565656 100644 --- a/website/public/locales/en/message.json +++ b/website/public/locales/en/message.json @@ -8,7 +8,7 @@ "open_new_tab_action": "Open in new tab", "parent": "Parent", "reactions": "Reactions", - "recent_messages": "Recent Messages", + "recent_messages": "Recent Messages in {{language}}", "report_action": "Report", "report_placeholder": "Why should this message be reviewed?", "report_title": "Report", diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index 826fc195..2ec76c85 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -314,8 +314,8 @@ export class OasstApiClient { return this.get(`/api/v1/messages?${params}`); } - fetch_recent_messages() { - return this.get(`/api/v1/messages`); + fetch_recent_messages(lang: string) { + return this.get(`/api/v1/messages`, { lang }); } fetch_message_children(messageId: string) { diff --git a/website/src/pages/api/messages/index.ts b/website/src/pages/api/messages/index.ts index fbcaee3c..978ed3ff 100644 --- a/website/src/pages/api/messages/index.ts +++ b/website/src/pages/api/messages/index.ts @@ -1,9 +1,11 @@ import { withoutRole } from "src/lib/auth"; import { createApiClient } from "src/lib/oasst_client_factory"; +import { getUserLanguage } from "src/lib/users"; const handler = withoutRole("banned", async (req, res, token) => { const client = await createApiClient(token); - const messages = await client.fetch_recent_messages(); + const userLanguage = getUserLanguage(req); + const messages = await client.fetch_recent_messages(userLanguage); res.status(200).json(messages); }); diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx index 8d950a2b..4cae792a 100644 --- a/website/src/pages/messages/index.tsx +++ b/website/src/pages/messages/index.tsx @@ -1,6 +1,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; import { useTranslation } from "next-i18next"; +import { useCookies } from "react-cookie"; import { getDashboardLayout } from "src/components/Layout"; import { MessageTable } from "src/components/Messages/MessageTable"; import { get } from "src/lib/api"; @@ -15,6 +16,8 @@ const MessagesDashboard = () => { const { data: messages } = useSWRImmutable("/api/messages", get, { revalidateOnMount: true }); const { data: userMessages } = useSWRImmutable(`/api/messages/user`, get, { revalidateOnMount: true }); + const [cookies] = useCookies(["NEXT_LOCALE"]); + const currentLanguage = cookies["NEXT_LOCALE"] || "en"; return ( <> @@ -24,7 +27,9 @@ const MessagesDashboard = () => { - {t("recent_messages")} + {t("recent_messages", { + language: new Intl.DisplayNames([currentLanguage], { type: "language" }).of(currentLanguage), + })} Date: Thu, 9 Feb 2023 17:31:04 +0700 Subject: [PATCH 06/10] fix message reaction count visibility (#1382) close #1367 * Only show when the user is author, admin or has reacted to the message for all pages. * Add disabled state when user is author --- .../Messages/MessageEmojiButton.stories.tsx | 15 ++++----- .../Messages/MessageEmojiButton.tsx | 31 ++++++++++++++++--- .../components/Messages/MessageTableEntry.tsx | 3 +- website/src/types/Conversation.ts | 8 +++++ 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/website/src/components/Messages/MessageEmojiButton.stories.tsx b/website/src/components/Messages/MessageEmojiButton.stories.tsx index b083c966..5d3e8be6 100644 --- a/website/src/components/Messages/MessageEmojiButton.stories.tsx +++ b/website/src/components/Messages/MessageEmojiButton.stories.tsx @@ -11,17 +11,16 @@ export default { const Template = ({ emoji, count, - checked, - showCount, + ...rest }: { emoji: string; count: number; checked?: boolean; - showCount: boolean; + userIsAuthor: boolean; + disabled?: boolean; + userReacted: boolean; }) => { - return ( - - ); + return ; }; export const Default = Template.bind({}); @@ -29,7 +28,9 @@ Default.args = { emoji: "+1", count: 7, checked: false, - showCount: true, + userIsAuthor: false, + disabled: false, + userReacted: true, }; export const BigNumber = Template.bind({}); diff --git a/website/src/components/Messages/MessageEmojiButton.tsx b/website/src/components/Messages/MessageEmojiButton.tsx index f140a789..e3acb3c0 100644 --- a/website/src/components/Messages/MessageEmojiButton.tsx +++ b/website/src/components/Messages/MessageEmojiButton.tsx @@ -1,4 +1,5 @@ import { Button } from "@chakra-ui/react"; +import { useHasRole } from "src/hooks/auth/useHasRole"; import { MessageEmoji } from "src/types/Conversation"; import { emojiIcons } from "src/types/Emoji"; @@ -6,12 +7,27 @@ interface MessageEmojiButtonProps { emoji: MessageEmoji; checked?: boolean; onClick: () => void; - showCount: boolean; + userIsAuthor: boolean; + disabled?: boolean; + userReacted: boolean; } -export const MessageEmojiButton = ({ emoji, checked, onClick, showCount }: MessageEmojiButtonProps) => { +export const MessageEmojiButton = ({ + emoji, + checked, + onClick, + userIsAuthor, + disabled, + userReacted, +}: MessageEmojiButtonProps) => { const EmojiIcon = emojiIcons.get(emoji.name); - if (!EmojiIcon) return <>; + const isAdmin = useHasRole("admin"); + + if (!EmojiIcon) return null; + + const isDisabled = !!(userIsAuthor ? true : disabled); + const showCount = (emoji.count > 0 && userReacted) || userIsAuthor || isAdmin; + return ( ); }; diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index a82cc50c..c469de70 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -116,7 +116,8 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE key={emoji} emoji={{ name: emoji, count }} checked={emojiState.user_emojis.includes(emoji)} - showCount={emojiState.user_emojis.filter((emoji) => emoji === "+1" || emoji === "-1").length > 0} + userReacted={emojiState.user_emojis.length > 0} + userIsAuthor={message.user_is_author} onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))} /> ); diff --git a/website/src/types/Conversation.ts b/website/src/types/Conversation.ts index 5eb86351..89713107 100644 --- a/website/src/types/Conversation.ts +++ b/website/src/types/Conversation.ts @@ -19,6 +19,14 @@ export interface Message extends MessageEmojis { parent_id: string; frontend_message_id?: string; user_id: string; + user_is_author: boolean | null; + deleted: boolean | null; + synthetic: boolean | null; + message_tree_id: string; + ranking_count: number | null; + rank: number | null; + model_name: string | null; + review_count: number | null; } export interface Conversation { From 0187d4c17af771f9d5900480a9a2ff5878072171 Mon Sep 17 00:00:00 2001 From: Mehdi Zibout <113536090+mehdi-zibout@users.noreply.github.com> Date: Thu, 9 Feb 2023 11:31:43 +0100 Subject: [PATCH 07/10] Fix avatar not displaying correctly in the navbar (#1383) --- website/src/components/Header/UserMenu.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/components/Header/UserMenu.tsx b/website/src/components/Header/UserMenu.tsx index 8b5de035..cad2dbe8 100644 --- a/website/src/components/Header/UserMenu.tsx +++ b/website/src/components/Header/UserMenu.tsx @@ -69,7 +69,7 @@ export function UserMenu() { - + {session.user.name || "New User"} From e8f1bd0737e30ca30a69e9b2136394287e1c95ef Mon Sep 17 00:00:00 2001 From: lytjedk <124836469+lytjedk@users.noreply.github.com> Date: Thu, 9 Feb 2023 10:54:26 +0000 Subject: [PATCH 08/10] Create dashboard.json (#1381) --- website/public/locales/da/dashboard.json | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 website/public/locales/da/dashboard.json diff --git a/website/public/locales/da/dashboard.json b/website/public/locales/da/dashboard.json new file mode 100644 index 00000000..79c790ef --- /dev/null +++ b/website/public/locales/da/dashboard.json @@ -0,0 +1,8 @@ +{ + "grab_a_task": "Tag en opgave!", + "create": "Lav", + "evaluate": "Evaluer", + "label": "Label", + "dashboard": "Dashboard", + "go": "Start" +} From f3debae3ca566c089a75c441cd2d336b16cf47be Mon Sep 17 00:00:00 2001 From: CactiStaccingCrane <121211419+CactiStaccingCrane@users.noreply.github.com> Date: Thu, 9 Feb 2023 18:44:31 +0700 Subject: [PATCH 09/10] Fix Vietnamese translation (#1384) --- website/public/locales/vi/common.json | 13 +++-------- website/public/locales/vi/index.json | 8 +++---- website/public/locales/vi/leaderboard.json | 6 ++--- website/public/locales/vi/message.json | 10 ++++----- website/public/locales/vi/side_menu.json | 12 ++++++++++ website/public/locales/vi/tasks.json | 26 +++++++++++----------- 6 files changed, 40 insertions(+), 35 deletions(-) create mode 100644 website/public/locales/vi/side_menu.json diff --git a/website/public/locales/vi/common.json b/website/public/locales/vi/common.json index 71309549..aa2cc855 100644 --- a/website/public/locales/vi/common.json +++ b/website/public/locales/vi/common.json @@ -3,21 +3,17 @@ "account_settings": "Tài khoản", "admin_dashboard": "Trang cho admin", "connect": "Liên hệ", - "conversational": "Chatbot AI cho tất cả mọi người", - "copied": "Copied", + "conversational": "Chatbot AI cho mọi người", + "copied": "Đã sao chép", "dark_mode": "Giao diện tối", - "dashboard_home": "Trang chính", "dashboard": "Trang chính", "delete": "Xoá", "discord": "Discord", "docs": "Hướng dẫn", "github": "GitHub", - "leaderboard": "Bảng xếp hạng", "legal": "Luật lệ", "light_mode": "Giao diện sáng", "loading": "Đang tải...", - "messages_dashboard": "Bảng xếp hạng tin nhắn", - "messages": "Tin nhắn", "more_information": "Xem thêm", "no": "Không", "privacy_policy": "Chính sách bảo mật", @@ -26,11 +22,8 @@ "sign_out": "Đăng xuất", "status_dashboard": "Trang hiện tình trạng", "status": "Tình trạng", - "success": "Success", + "success": "Thành công", "terms_of_service": "Điều khoản sử dụng", "title": "Open Assistant", - "user_leaderboard": "Bảng xếp hạng người dùng", - "users_dashboard": "Trang về người dùng", - "users": "Người dùng", "yes": "Có" } diff --git a/website/public/locales/vi/index.json b/website/public/locales/vi/index.json index 43022145..1c6b5e19 100644 --- a/website/public/locales/vi/index.json +++ b/website/public/locales/vi/index.json @@ -1,15 +1,15 @@ { "blurb": "Đây sẽ là cuộc cách mạng công nghệ mới.", - "blurb1": "Giống như cách Stable Diffusion đã cho mọi người công cụ để làm tranh ảnh bằng AI, Open Assistant sẽ làm như vậy với con chatbot mã nguồn mở mạnh nhất thế giới.", - "description": "Chatbot trí tuệ nhân tạo mã nguồn mở, dựa trên mô hình ngôn ngữ lớn của LAION và các tình nguyện viên trên toàn thế giới.", + "blurb1": "Giống như Stable Diffusion với tranh ảnh bằng AI, Open Assistant sẽ làm tương tự như vậy với con chatbot mã nguồn mở mạnh nhất thế giới.", + "description": "Chatbot trí tuệ nhân tạo mã nguồn mở, dựa trên mô hình ngôn ngữ lớn của LAION và các tình nguyện.", "faq_items": { "q0": "Open Assistant bây giờ thế nào rồi?", "a0": "Dự án này đang trong giai đoạn phát triển, từ những nghiên cứu về sử dụng RLHF (học từ phản hồi con người) trong các mô hình ngôn ngữ lớn.", "q1": "Open Assistant được phát triển bởi ai?", - "a1": "Open Assistant là dự ản được phát triển bởi LAION and các tình nguyện viên trên toàn thế giới." + "a1": "Open Assistant là dự ản được phát triển bởi LAION and các tình nguyện viên." }, "faq_title": "Câu hỏi", "join_us_description": "Các dự án mã nguồn mở được phát triển bởi những người như bạn. Triết lý mã nguồn mở là hợp tác để tạo và phát triển công nghệ mới mà làm giàu thế giới quanh ta. Bạn có muốn tham gia không? Liên hệ chúng tôi ở đây:", "join_us_title": "Tham gia", - "subtitle": "Chatbot AI cho tất cả mọi người" + "subtitle": "Chatbot AI cho mọi người" } diff --git a/website/public/locales/vi/leaderboard.json b/website/public/locales/vi/leaderboard.json index ebe25dda..d5b4e167 100644 --- a/website/public/locales/vi/leaderboard.json +++ b/website/public/locales/vi/leaderboard.json @@ -1,15 +1,15 @@ { "daily": "Ngày", - "label": "Nhãn", + "label": "Số nhãn", "last_updated_at": "Cập nhật lần cuối: {{val, datetime}}", "leaderboard": "Bảng xếp hạng", "monthly": "Tháng", "next": "Tiếp", "overall": "Tổng quan", "previous": "Trước", - "prompt": "Câu đầu tiên", + "prompt": "Số câu đầu", "rank": "Xếp hạng", - "reply": "Câu trả lời", + "reply": "Số câu trả lời", "score": "Điểm", "top_5_contributors_today": "Top 5 người đóng góp", "user": "Tên người dùng", diff --git a/website/public/locales/vi/message.json b/website/public/locales/vi/message.json index baabe4b6..45d53060 100644 --- a/website/public/locales/vi/message.json +++ b/website/public/locales/vi/message.json @@ -1,20 +1,20 @@ { - "copy_message_id": "Copy message ID", + "copy_message_id": "Sao chép ID", "label_action": "Nhãn", "label_title": "Nhãn", "message": "Tin nhắn", - "message_deleted": "Message deleted", + "message_deleted": "Tin nhắn đã xoá", "open_new_tab_action": "Mở ở trang mới", "parent": "Tin nhắn gốc", "reactions": "Bình luận", "recent_messages": "Tin nhắn gần đây", "report_action": "Báo cáo", - "report_placeholder": "Tại sao tin nhắn này cần được báo cáo?", + "report_placeholder": "Nêu lý do để báo cáo", "report_title": "Báo cáo", "send_report": "Gửi", - "stop_tree": "Stop tree", + "stop_tree": "Dừng nhánh tin nhắn", "submit_labels": "Gửi", - "tree_stopped": "Tree stopped {{id}}", + "tree_stopped": "Nhánh tin nhắn {{id}} đã dừng", "view_user": "Xem người dùng", "your_recent_messages": "Tin nhắn gần đây của bạn" } diff --git a/website/public/locales/vi/side_menu.json b/website/public/locales/vi/side_menu.json new file mode 100644 index 00000000..6b06846f --- /dev/null +++ b/website/public/locales/vi/side_menu.json @@ -0,0 +1,12 @@ +{ + "dashboard": "Trang chính", + "dashboard_home": "Trang chính", + "leaderboard": "Bảng xếp hạng", + "messages": "Tin nhắn", + "messages_dashboard": "Bảng xếp hạng tin nhắn", + "status": "Tình trạng", + "status_dashboard": "Trang hiện tình trạng", + "user_leaderboard": "Bảng xếp hạng người dùng", + "users": "Người dùng", + "users_dashboard": "Trang về người dùng" +} diff --git a/website/public/locales/vi/tasks.json b/website/public/locales/vi/tasks.json index e197c410..7884f5fe 100644 --- a/website/public/locales/vi/tasks.json +++ b/website/public/locales/vi/tasks.json @@ -1,22 +1,22 @@ { "available_task_count": "{{count}} việc", "classify_assistant_reply": { - "label": "Phân loại tin nhắn của Open Assistant", - "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.", + "label": "Phân loại các tin nhắn của Open Assistant", + "desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.", "overview": "Từ cuộc trò truyện ở dưới, trả lời các câu hỏi về câu trả lời cuối trong cuộc trò truyện." }, "classify_initial_prompt": { - "label": "Phân loại tin nhắn đầu", - "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.", + "label": "Phân loại các tin nhắn đầu tiên", + "desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.", "overview": "Đọc tin nhắn đầu và trả lời các câu hỏi." }, "classify_prompter_reply": { "label": "Phân loại tin nhắn người dùng", - "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn.", + "desc": "Tạo nhãn dữ liệu để đánh giá tin nhắn.", "overview": "Từ cuộc trò truyện ở dưới, trả lời các câu hỏi về câu trả lời cuối trong cuộc trò truyện." }, "create_initial_prompt": { - "label": "Tạo tin nhắn đầu", + "label": "Tạo tin nhắn đầu tiên", "desc": "Viết tin nhắn đầu tiên để làm bộ dữ liệu cho Open Assistant.", "overview": "Viết tin nhắn đầu tiên để Open Assistant trả lời", "instruction": "Viết tin nhắn đầu", @@ -27,17 +27,17 @@ "unchanged_message": "Are you sure you would like to continue?" }, "label_assistant_reply": { - "label": "Tạo nhãn cho tin nhắn của Open Assistant", + "label": "Đánh giá các tin nhắn của Open Assistant", "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn của Open Assistant.", "overview": "Từ cuộc trò truyện ở dưới, tạo nhãn dữ liệu cho tin nhắn sau." }, "label_initial_prompt": { - "label": "Tạo nhãn cho tin nhắn đầu", + "label": "Đánh giá các tin nhắn đầu tiên", "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn đầu.", "overview": "Tạo nhãn dữ liệu cho tin nhắn sau." }, "label_prompter_reply": { - "label": "Tạo nhãn cho tin nhắn người dùng", + "label": "Đánh giá các tin nhắn người dùng", "desc": "Tạo nhãn dữ liệu đánh giá tin nhắn của người dùng.", "overview": "Từ cuộc trò truyện ở dưới, tạo nhãn dữ liệu cho tin nhắn sau." }, @@ -46,28 +46,28 @@ "desc": "Giúp cải thiện Open Assistant bằng cách làm một việc ngẫu nhiên." }, "rank_assistant_replies": { - "label": "Xếp hạng câu trả lời của Open Assistant", + "label": "Xếp hạng các câu trả lời của Open Assistant", "desc": "Đánh giá độ chính xác và dễ đọc của các câu trả lời mà Open Assistant đưa ra.", "overview": "Từ những câu trả lời của Open Assistant, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.", "unchanged_title": "Chưa thay đổi thứ tự", "unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?" }, "rank_initial_prompts": { - "label": "Xếp hạng tin nhắn đầu tiên", + "label": "Xếp hạng các tin nhắn đầu tiên", "desc": "Đánh giá độ chính xác và dễ đọc của các câu trả lời của tin nhắn đầu tiên.", "overview": "Từ những tin nhắn đầu sau, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.", "unchanged_title": "Chưa thay đổi thứ tự", "unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?" }, "rank_user_replies": { - "label": "Xếp hạng câu trả lời của người dùng", + "label": "Xếp hạng các câu trả lời của người dùng", "desc": "Giúp cải thiện câu trả lời của Open Assistant.", "overview": "Từ những câu trả lời của người dùng, xếp hạng chúng theo chất lượng, tốt nhât ở trên, tệ nhất ở dưới.", "unchanged_title": "Chưa thay đổi thứ tự", "unchanged_message": "Bạn chưa thay đổi thứ tự tin nhắn. Bạn có chắc muốn lưu không?" }, "reply_as_assistant": { - "label": "Đóng vai Open Assistant", + "label": "Đóng vai trợ lý", "desc": "Giúp cải thiện câu trả lời của Open Assistant.", "overview": "Tạo câu trả lời phù hợp cho cuộc trò truyện dưới đây", "response_placeholder": "Viết vào đây..." From 1fd713bba353e002d92ea5f84385eff528a24188 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 9 Feb 2023 13:15:53 +0100 Subject: [PATCH 10/10] fix lang filtering --- backend/oasst_backend/prompt_repository.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index f2119e3b..eb244d3b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -947,6 +947,9 @@ class PromptRepository: if deleted is not None: qry = qry.filter(Message.deleted == deleted) + if lang is not None: + qry = qry.filter(Message.lang == lang) + if desc: qry = qry.order_by(Message.created_date.desc(), Message.id.desc()) else: @@ -955,9 +958,6 @@ class PromptRepository: if limit is not None: qry = qry.limit(limit) - if lang is not None: - qry = qry.filter(Message.lang == lang) - return self._add_user_emojis_all(qry) def update_children_counts(self, message_tree_id: UUID):