Merge branch 'main' into 766_admin_enhancement

This commit is contained in:
notmd
2023-01-20 23:12:43 +07:00
63 changed files with 874 additions and 225 deletions
@@ -0,0 +1,26 @@
"""add ix_user_display_name_id

Revision ID: 4f26fec4d204
Revises: 7f0a28a156f4
Create Date: 2023-01-19 22:00:00

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "4f26fec4d204"
down_revision = "7f0a28a156f4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the unique index ``ix_user_display_name_id`` on ``user(display_name, id)``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the ``ix_user_display_name_id`` index from the ``user`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("ix_user_display_name_id", table_name="user")
    # ### end Alembic commands ###
+17 -22
View File
@@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/", response_model=list[protocol.FrontEndUser])
def get_users(
@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True)
def get_users_ordered_by_username(
api_client_id: Optional[UUID] = None,
max_count: Optional[int] = Query(100, gt=0, le=10000),
gt: Optional[str] = None,
lt: Optional[str] = None,
gte_username: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_username: Optional[str] = None,
lt_id: Optional[UUID] = None,
search_text: Optional[str] = None,
auth_method: Optional[str] = None,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method)
return [u.to_protocol_frontend_user() for u in users]
@router.get("/by_display_name")
def query_frontend_users_by_display_name(
search_text: str,
exact: bool = False,
api_client_id: UUID = None,
max_count: int = Query(20, gt=0, le=1000),
auth_method: str = None,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
users = ur.query_users_by_display_name(
search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method
users = ur.query_users_ordered_by_username(
api_client_id=api_client_id,
gte_username=gte_username,
gt_id=gt_id,
lte_username=lte_username,
lt_id=lt_id,
auth_method=auth_method,
search_text=search_text,
limit=max_count,
)
return [u.to_protocol_frontend_user() for u in users]
+57 -3
View File
@@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/users/{user_id}", response_model=protocol.FrontEndUser)
@router.get("/by_username", response_model=list[protocol.FrontEndUser])
def get_users_ordered_by_username(
    api_client_id: Optional[UUID] = None,
    gte_username: Optional[str] = None,
    gt_id: Optional[UUID] = None,
    lte_username: Optional[str] = None,
    lt_id: Optional[UUID] = None,
    search_text: Optional[str] = None,
    auth_method: Optional[str] = None,
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Return front-end users as produced by ``query_users_ordered_by_username``.

    All filter and cursor parameters are forwarded unchanged to
    ``UserRepository.query_users_ordered_by_username``; ``max_count`` is
    passed as the repository's ``limit``. Each returned user is converted
    with ``to_protocol_frontend_user()``.
    """
    query_args = dict(
        api_client_id=api_client_id,
        gte_username=gte_username,
        gt_id=gt_id,
        lte_username=lte_username,
        lt_id=lt_id,
        auth_method=auth_method,
        search_text=search_text,
        limit=max_count,
    )
    found = UserRepository(db, api_client).query_users_ordered_by_username(**query_args)
    return [user.to_protocol_frontend_user() for user in found]
@router.get("/by_display_name", response_model=list[protocol.FrontEndUser])
def get_users_ordered_by_display_name(
    api_client_id: Optional[UUID] = None,
    gte_display_name: Optional[str] = None,
    gt_id: Optional[UUID] = None,
    lte_display_name: Optional[str] = None,
    lt_id: Optional[UUID] = None,
    auth_method: Optional[str] = None,
    search_text: Optional[str] = None,
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Return front-end users as produced by ``query_users_ordered_by_display_name``.

    All filter and cursor parameters are forwarded unchanged to
    ``UserRepository.query_users_ordered_by_display_name``; ``max_count`` is
    passed as the repository's ``limit``. Each returned user is converted
    with ``to_protocol_frontend_user()``.
    """
    query_args = dict(
        api_client_id=api_client_id,
        gte_display_name=gte_display_name,
        gt_id=gt_id,
        lte_display_name=lte_display_name,
        lt_id=lt_id,
        auth_method=auth_method,
        search_text=search_text,
        limit=max_count,
    )
    found = UserRepository(db, api_client).query_users_ordered_by_display_name(**query_args)
    return [user.to_protocol_frontend_user() for user in found]
@router.get("/{user_id}", response_model=protocol.FrontEndUser)
def get_user(
user_id: UUID,
api_client_id: UUID = None,
@@ -31,7 +85,7 @@ def get_user(
return user.to_protocol_frontend_user()
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT)
def update_user(
user_id: UUID,
enabled: Optional[bool] = None,
@@ -46,7 +100,7 @@ def update_user(
ur.update_user(user_id, enabled, notes)
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
def delete_user(
user_id: UUID,
db: Session = Depends(deps.get_db),
+4 -1
View File
@@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel
class User(SQLModel, table=True):
__tablename__ = "user"
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
__table_args__ = (
Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),
Index("ix_user_display_name_id", "display_name", "id", unique=True),
)
id: Optional[UUID] = Field(
sa_column=sa.Column(
+76 -26
View File
@@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from sqlmodel import Session, and_, or_
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -135,13 +135,16 @@ class UserRepository:
self.db.add(user)
return user
def query_users(
def query_users_ordered_by_username(
self,
api_client_id: Optional[UUID] = None,
limit: Optional[int] = 20,
gt: Optional[str] = None,
lt: Optional[str] = None,
gte_username: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_username: Optional[str] = None,
lt_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
search_text: Optional[str] = None,
limit: Optional[int] = 100,
) -> list[User]:
if not self.api_client.trusted:
if not api_client_id:
@@ -150,34 +153,52 @@ class UserRepository:
if api_client_id != self.api_client.id:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
users = self.db.query(User)
qry = self.db.query(User).order_by(User.username, User.id)
if api_client_id:
users = users.filter(User.api_client_id == api_client_id)
if gte_username is not None:
if gt_id:
qry = qry.filter(
or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id))
)
else:
qry = qry.filter(User.username >= gte_username)
elif gt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if lte_username is not None:
if lt_id:
qry = qry.filter(
or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id))
)
else:
qry = qry.filter(User.username <= lte_username)
elif lt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if auth_method:
users = users.filter(User.auth_method == auth_method)
qry = qry.filter(User.auth_method == auth_method)
if api_client_id:
qry = qry.filter(User.api_client_id == api_client_id)
users = users.order_by(User.id)
if gt:
users = users.filter(User.id > gt)
if lt:
users = users.filter(User.id < lt).order_by(None).order_by(User.id.desc())
if search_text:
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
qry = qry.filter(User.username.like(pattern))
if limit is not None:
users = users.limit(limit)
qry = qry.limit(limit)
return users.all()
return qry.all()
def query_users_by_display_name(
def query_users_ordered_by_display_name(
self,
search_text: str,
exact: Optional[bool] = False,
limit: Optional[int] = 20,
gte_display_name: Optional[str] = None,
gt_id: Optional[UUID] = None,
lte_display_name: Optional[str] = None,
lt_id: Optional[UUID] = None,
api_client_id: Optional[UUID] = None,
auth_method: Optional[str] = None,
search_text: Optional[str] = None,
limit: Optional[int] = 100,
) -> list[User]:
if not self.api_client.trusted:
if not api_client_id:
@@ -186,11 +207,40 @@ class UserRepository:
if api_client_id != self.api_client.id:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
qry = self.db.query(User).order_by(User.display_name)
qry = self.db.query(User).order_by(User.display_name, User.id)
if exact:
qry = qry.filter(User.display_name == search_text)
else:
if gte_display_name is not None:
if gt_id:
qry = qry.filter(
or_(
User.display_name > gte_display_name,
and_(User.display_name == gte_display_name, User.id > gt_id),
)
)
else:
qry = qry.filter(User.display_name >= gte_display_name)
elif gt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if lte_display_name is not None:
if lt_id:
qry = qry.filter(
or_(
User.display_name < lte_display_name,
and_(User.display_name == lte_display_name, User.id < lt_id),
)
)
else:
qry = qry.filter(User.display_name <= lte_display_name)
elif lt_id:
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
if auth_method:
qry = qry.filter(User.auth_method == auth_method)
if api_client_id:
qry = qry.filter(User.api_client_id == api_client_id)
if search_text:
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
qry = qry.filter(User.display_name.like(pattern))
@@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base
learning_rate: 1e-5
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 32
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
eval_steps: 200
+20
View File
@@ -60,6 +60,26 @@ python trainer.py --configs defaults your-model-name --deepspeed
## Dataset choices
To select a translation pair for
[WMT](https://huggingface.co/datasets/wmt19) or
[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation, append
the supported language pair to the dataset name as a suffix:
```
datasets:
- wmt2019_zh-en
- wmt2019_ru-en
- wmt2019_de-en
- ted_trans_nl-en
- ted_trans_de-ja
```
Currently only these languages are supported via prompt translation:
```
ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh
```
## Results
Experimental results in wandb
@@ -29,6 +29,14 @@ defaults:
- soda
- joke
- gsm8k
- dive_mt
- wmt2019_zh-en
- wmt2019_ru-en
- wmt2019_de-en
- ted_trans_nl-en
- ted_trans_de-ja
- instruct_tuning
- wmt2019_de-en
- samsum
- soda_dialogue
cache_dir: .cache
@@ -0,0 +1,27 @@
# Dataset collections overview:
Currently, the datasets can be divided into 3 classes:
- language knowledge
- summarization
- translation
- dialogue : don't let user know you are a robot
- STEM : knowledge about the world
- coding
- world knowledge <= ideally we want to handle this via prefix context
Issues and TODO:
- As the number of datasets grows, how can we reduce how often this section
needs to be updated?
- Ideally, we only update the config YAML and new datasets are downloaded from
the hub.
- One possible idea is to upload the transformed format of these datasets to
the OA hub.
@@ -1,11 +1,26 @@
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"]
SUMMARIZATION_DATASETS = [
"xsum",
"cnn_dailymail",
"samsum",
"multi_news",
"scitldr",
"billsum",
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
def train_val_dataset(dataset, val_split=0.2):
@@ -25,14 +40,34 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name in SUMMARIZATION_DATASETS:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
val_name = "validation" if dataset_name not in ["billsum"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
if dataset_name == "debate_sum":
train, eval = train_val_dataset(train, val_split=0.2)
else:
val_name = "validation" if dataset_name not in ["billsum"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
dataset = TEDTalk(pair=language_pair, split="train")
train, eval = train_val_dataset(dataset, val_split=0.2)
elif "wmt2019" in dataset_name:
language_pair = dataset_name.split("_")[-1]
train = WMT2019(pair=language_pair, split="train")
eval = WMT2019(pair=language_pair, split="validation")
elif dataset_name == "dive_mt":
dataset = DiveMT()
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "prompt_dialogue":
dataset = PromptGeneratedDataset(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "prosocial_dialogue":
train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train")
eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation")
elif dataset_name == "explain_prosocial":
train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train")
eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation")
elif dataset_name == "soda":
dataset = SODA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
@@ -42,6 +77,9 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "joke":
dataset = JokeExplaination(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "instruct_tuning":
dataset = InstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
else:
raise ValueError(f"Unknown dataset {dataset_name}")
@@ -25,6 +25,7 @@ class DialogueDataCollator:
for feature_one in features:
assert len(feature_one) % 2 == 0, "Number of messages must be even"
# TODO: we should push this to dataset __getitem__
messages = [
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
+ x
@@ -1,3 +1,4 @@
import json
import os
from urllib.request import urlopen
@@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset):
we are ignoring results with multiple lines for now
"""
name = "prompt_dialogue"
url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt"
def __init__(self, cache_dir) -> None:
@@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset):
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
class InstructionTuning(Dataset):
    """
    We have seen some promising capabilities from instruction tuning
    with the following mix of datasets that are derived from datasets
    available online.
    The files for this data are in json format as a list of tuples
    where each tuple is (source, instruction_response_pair)

    - instruction_tuning_dataset_alpha_part1.json
    - instruction_tuning_dataset_alpha_part2.json

    Each ``instruction_response_pair`` is split on the first blank line
    ("\\n\\n") into a (question, answer) tuple; rows without that separator
    are skipped instead of aborting the whole load.

    Not to be confused with Unnatural Instructions.
    """

    name = "instruction_dataset"
    url_part_2 = (
        "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json"
    )
    url_part_1 = (
        "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json"
    )

    def __init__(self, cache_dir) -> None:
        """Load both dataset parts, downloading any part missing from ``cache_dir``.

        Args:
            cache_dir: directory holding the cached JSON files; created if absent.
        """
        super().__init__()
        os.makedirs(cache_dir, exist_ok=True)

        self.pairs = []
        for file_link in (self.url_part_1, self.url_part_2):
            basename = file_link.split("/")[-1]
            instruction_tune_file = os.path.join(cache_dir, basename)
            # Download only when the part is not cached yet.
            if not os.path.exists(instruction_tune_file):
                with urlopen(file_link) as file:
                    content = file.read().decode()
                with open(instruction_tune_file, "w", encoding="utf-8") as fout:
                    fout.write(content)

            with open(instruction_tune_file, "r", encoding="utf-8") as f:
                rows = json.load(f)

            for _, response_pair in rows:
                # Only the first blank line separates question from answer;
                # later blank lines stay inside the answer.
                question, separator, answer = response_pair.partition("\n\n")
                if not separator:
                    # Malformed row (no question/answer boundary): skip it.
                    continue
                answer = answer.replace("<|endoftext|>", "").strip()
                self.pairs.append((question, answer))

    def __len__(self):
        """Return the number of (question, answer) pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return the (question, answer) tuple at ``index``."""
        question, answer = self.pairs[index]
        return question, answer
@@ -1,11 +1,18 @@
"""
Open / close book QA datasets
"""
import json
import os
import re
from urllib.request import urlopen
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
@@ -75,6 +82,9 @@ class QADataset(Dataset):
class WebGPT(Dataset):
name = "webgpt"
def __init__(self) -> None:
super().__init__()
@@ -89,7 +99,9 @@ class WebGPT(Dataset):
self.index2question[len(self.index2question)] = question
# only keep the best answer
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
questions[question] = re_reference_remove.sub(
"", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
)
self.questions = questions
@@ -103,6 +115,9 @@ class WebGPT(Dataset):
class SODA(Dataset):
name = "soda"
def process_soda_convo(self, data):
pairs = []
play_as = data["speakers"][1]
@@ -207,8 +222,8 @@ class SODADialogue(Dataset):
class JokeExplaination(Dataset):
""" """
name = "joke"
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
def __init__(self, cache_dir) -> None:
@@ -240,3 +255,6 @@ class JokeExplaination(Dataset):
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
# https://huggingface.co/datasets/aquamuse
@@ -1,3 +1,6 @@
"""
Summarize different spectrum of documents
"""
import random
from datasets import load_dataset
@@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = {
}
summarization_config_mapping = {
"cnn_dailymail": ("3.0.0",),
"samsum": (),
"xsum": (),
"multi_news": (),
"scitldr": ("AIC",),
"billsum": (),
"reddit": (),
"cnn_dailymail": (
"cnn_dailymail",
"3.0.0",
),
"samsum": ("samsum",),
"xsum": ("xsum",),
"multi_news": ("multi_news",),
"scitldr": (
"scitldr",
"AIC",
),
"billsum": ("billsum",),
"reddit": ("reddit",),
"tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news
"debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum
}
summarization_name_mapping = {
@@ -29,6 +40,8 @@ summarization_name_mapping = {
"scitldr": ("source", "target"),
"billsum": ("text", "summary"),
"reddit": ("content", "summary"),
"tldr_news": ("content", "headline"),
"debate_sum": ("Full-Document", "Extract"),
}
@@ -43,7 +56,7 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split):
self.name = dataset
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
@@ -0,0 +1,68 @@
"""
SFT dataset to reject toxic questions
"""
import random
from datasets import load_dataset
from torch.utils.data import Dataset
class ProsocialDialogueExplaination(Dataset):
    """Question/answer pairs asking why a context received a given safety annotation.

    Built from the ``allenai/prosocial-dialog`` dataset: every
    (safety annotation, annotation reason) pair of a row is rendered through
    one randomly chosen (prompt, answer) template from ``TEMPLATE``.
    """

    name = "prosocial_explain"
    TEMPLATE = [
        # 0 : reply or sentence of interest, 1 : reason of caution
        ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
        ("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"),
        ("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"),
        ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"),
    ]

    def __init__(self, split="train", cache_dir=".cache") -> None:
        """Load ``split`` of the dataset and materialise all pairs in ``self.pairs``."""
        super().__init__()
        records = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
        self.pairs = []
        for record in records:
            annotated = zip(record["safety_annotations"], record["safety_annotation_reasons"])
            for label, reason in annotated:
                # One template draw per annotation, prompt and answer always come as a pair.
                question_tpl, reply_tpl = random.choice(self.TEMPLATE)
                question = question_tpl.format(record["context"], label)
                reply = reply_tpl.format(reason, label)
                self.pairs.append((question, reply))

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the (prompt, answer) tuple at ``idx``."""
        return self.pairs[idx]
class ProsocialDialogue(Dataset):
    """
    (prefixed context, rule-of-thumb) pairs from ``allenai/prosocial-dialog``.

    ProsocialDialog, we set up a human-AI collaborative data creation framework,
    where GPT-3 generates the potentially unsafe utterances, and crowdworkers
    provide prosocial responses to them. This approach allows us to circumvent
    two substantial challenges:
    (1) there are no available large-scale corpora of multiturn prosocial conversations
    between humans
    (2) asking humans to write unethical, toxic, or problematic utterances could result
    in psychological harms (Roberts, 2017; Steiger et al., 2021).
    """

    # The description above must be the first statement of the class body to
    # become ``__doc__``; placed after ``name`` it is a discarded expression.
    name = "prosocial_dialogue"

    # Prepended to every context before it is stored as the prompt.
    PREFIX = "<prefix>You are now a prosocial chatbot, be caution and casual when reply</prefix>"

    def __init__(self, split="train", cache_dir=".cache") -> None:
        """Load ``split`` and emit one pair per entry of each row's ``rots`` list."""
        super().__init__()
        dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
        self.pairs = []
        for row in dataset:
            prompt = self.PREFIX + row["context"]
            self.pairs.extend((prompt, answer) for answer in row["rots"])

    def __len__(self):
        """Return the number of (prompt, answer) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the (prompt, answer) tuple at ``idx``."""
        return self.pairs[idx]
@@ -0,0 +1,142 @@
"""
List of translation datasets
GroNLP/divemt
fill in the blanks : https://huggingface.co/datasets/m_lama
"""
import random
from datasets import load_dataset
from torch.utils.data import Dataset
# postfix prompt
# Maps a target-language code to prompt templates asking for a translation
# into that language. Every template must contain exactly one "{}" placeholder
# because callers format it with a single positional argument (the text to
# translate).
TRANSLATION_PROMPT = {
    "zh": [  # simplified or any chinese which was not mentioned
        "Translate to chinese simplified: {}",
        "{}, translate to chinese",
        "{} give me the chinese translation",
        "翻译成中文: {}",
        "{} 这句中文翻译怎麽写?",
        "我需要这句话的中文翻译: {}",
    ],
    "zh-tw": [  # WMT code
        "{}. Translate to chinese traditional",
        "{}, translate to chinese",
        "{}. get chinese translation",
        "中文翻譯: {}",
        "幫我翻譯成中文: '{}'",
        "{} 這句中文翻譯怎麼寫?",
    ],
    "ja": [
        "{}: help me translate to japanese",
        "Need japanese translation: {}",
        "{}: にほんごやくをよこす",
        "{}: にほんごやくをおくれ",
        "{}: にほんごやくを じょす",
        "give me the japanese translation, {}",
    ],
    "de": [
        "{}: translate to german",
        "give me the german translation {}",
        "I want german translation {}",
        "{}, ins Deutsche übersetzen",
        "{}, Übersetzen ins Deutsche",
    ],
    "fr": [
        "{}. translate to french",
        "{} write in french",
        "{} french translation",
        "{} ,donnez moi la traduction française",
    ],
    "ko": [
        "{}. translate to Korean",
        "how do we write in korean: {}",
        "give me the korean translation: {}",
        "{}, 한국어 번역을 해주세요",
    ],
    "ms": [
        "{} translate to malay",
        "{} how do we write in Malay",
        "{} give me the malay translation",
        "{} , berikan saya terjemahan dalam bahasa melayu",
        # These two were fused by a missing comma (implicit string
        # concatenation), yielding one template with two placeholders.
        "{}, Jemahan di bahasa melayu",
        "{}, jemahkan ayat ini kepada bahasa melayu",
    ],
    "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
    "ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
    "tr": ["{}. türkçeye çevi̇ri̇n", "{} write in turkish", "turkish translation: '{}'", "türkçeye çevi̇rmek: {}"],
    "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"],
    "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
    "vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
    "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
}
class TranslationPair(Dataset):
    """Base dataset exposing the items of ``self.pairs`` by index.

    Starts with an empty ``pairs`` list; subclasses fill it.
    """

    def __init__(self) -> None:
        super().__init__()
        # Filled by subclasses.
        self.pairs = []

    def __len__(self):
        """Return how many pairs are stored."""
        size = len(self.pairs)
        return size

    def __getitem__(self, index):
        """Return the stored pair at ``index``."""
        item = self.pairs[index]
        return item
class WMT2019(TranslationPair):
    """Prompted translation pairs built from the ``wmt19`` dataset.

    ``pair`` is a ``"<src>-<tgt>"`` language-code string. For every row a
    random draw decides the direction: either ask for the ``tgt`` translation
    of the ``src`` text, or the reverse.
    """

    def __init__(self, pair="zh-en", split="train") -> None:
        super().__init__()
        corpus = load_dataset("wmt19", pair)[split]
        self.pairs = []
        src, tgt = pair.split("-")
        for record in corpus:
            texts = record["translation"]
            if random.random() > 0.5:
                # Forward: prompt in the target language's template set.
                given, wanted = src, tgt
            else:
                # Reverse direction: swap the roles of the two languages.
                given, wanted = tgt, src
            prompt = random.choice(TRANSLATION_PROMPT[wanted]).format(texts[given])
            self.pairs.append((prompt, texts[wanted]))
class DiveMT(TranslationPair):
    """Prompted translation pairs built from ``GroNLP/divemt`` (config ``main``).

    The language is taken from the prefix of each row's ``subject_id`` and
    mapped through ``REMAP``; rows whose mapped code has no entry in
    ``TRANSLATION_PROMPT`` are skipped.
    """

    # Three-letter codes found in ``subject_id`` -> codes used by TRANSLATION_PROMPT.
    REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}

    def __init__(self, split="train") -> None:
        super().__init__()
        records = load_dataset("GroNLP/divemt", "main")[split]
        target_field, source_field = "tgt_text", "src_text"
        for record in records:
            # ISO 639-2
            iso3 = record["subject_id"].split("_")[0]
            prompt_lang = self.REMAP[iso3]
            if prompt_lang not in TRANSLATION_PROMPT:
                continue

            forward = random.random() > 0.5
            if forward:
                given, wanted = source_field, target_field
            else:
                # Reverse direction always uses the "en" templates
                # (presumably ``src_text`` is English; verify against the dataset).
                prompt_lang = "en"
                given, wanted = target_field, source_field
            prompt = random.choice(TRANSLATION_PROMPT[prompt_lang]).format(record[given])
            self.pairs.append((prompt, record[wanted]))
class TEDTalk(TranslationPair):
    """Prompted translation pairs built from the ``ted_talks_iwslt`` dataset.

    ``pair`` is a ``"<src>-<tgt>"`` language-code string and ``year`` selects
    the dataset configuration. For every row a random draw decides which of
    the two languages is the requested translation.
    """

    # NOTE: DO NOT use Chinese pairs; they mix Traditional Chinese and Cantonese and are not clean.
    def __init__(self, pair="de-ja", split="train", year="2016") -> None:
        super().__init__()
        talks = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
        src, tgt = pair.split("-")
        for record in talks:
            texts = record["translation"]
            if random.random() > 0.5:
                given, wanted = src, tgt
            else:
                # Reverse direction: swap the roles of the two languages.
                given, wanted = tgt, src
            prompt = random.choice(TRANSLATION_PROMPT[wanted]).format(texts[given])
            self.pairs.append((prompt, texts[wanted]))
@@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
def test_all_datasets():
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke"]
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
config = Namespace(cache_dir=".cache")
for dataset_name in others + qa_base + summarize_base:
for dataset_name in translation + others + summarize_base + qa_base:
print(dataset_name)
train, eval = get_one_dataset(config, dataset_name)
# sanity check
@@ -48,7 +49,3 @@ def test_collate_fn():
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
assert batch["targets"].shape[1] <= 512
if __name__ == "__main__":
test_collate_fn()
@@ -1,34 +1,27 @@
import { OasstApiClient, OasstError } from "src/lib/oasst_api_client";
import type { BackendUserCore } from "src/types/Users";
describe("Contract test for Oasst API", function () {
// Assumes this is running the mock server.
const oasstApiClient = new OasstApiClient("http://localhost:8080", "test");
const testUser = {
id: "abcd",
display_name: "test",
auth_method: "local",
} as BackendUserCore;
it("can fetch a task", async () => {
expect(
await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
})
).to.be.not.null;
expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null;
});
it("can ack a task", async () => {
const task = await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
});
const task = await oasstApiClient.fetchTask("random", testUser);
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
});
it("can record a taskInteraction", async () => {
const task = await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
});
const task = await oasstApiClient.fetchTask("random", testUser);
expect(
await oasstApiClient.interactTask(
"text_reply_to_message",
@@ -36,11 +29,7 @@ describe("Contract test for Oasst API", function () {
"321",
"1",
{ text: "Test" },
{
sub: "test",
name: "test",
email: "test",
}
testUser
)
).to.be.not.null;
});
+14 -1
View File
@@ -1,4 +1,17 @@
{
"about": "About",
"account_settings": "Account",
"connect": "Connect",
"conversational": "Conversational AI for everyone.",
"dashboard": "Dashboard",
"discord": "Discord",
"github": "GitHub"
"docs": "Docs",
"github": "GitHub",
"legal": "Legal",
"privacy_policy": "Privacy Policy",
"report_a_bug": "Report a Bug",
"sign_in": "Sign In",
"sign_out": "Sign Out",
"terms_of_service": "Terms of Service",
"title": "Open Assistant"
}
@@ -1,19 +1,19 @@
import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
import Link from "next/link";
import { TaskTypes } from "../Tasks/TaskTypes";
import { TaskCategory, TaskCategoryLabels, TaskTypes } from "../Tasks/TaskTypes";
export const TaskOption = ({ displayTaskCategories }) => {
export const TaskOption = ({ displayTaskCategories }: { displayTaskCategories: TaskCategory[] }) => {
const backgroundColor = useColorModeValue("white", "gray.700");
return (
<Box className="flex flex-col gap-14">
{displayTaskCategories.map((category, categoryIndex) => (
<div key={categoryIndex}>
<Text className="text-2xl font-bold pb-4">{category}</Text>
{displayTaskCategories.map((category) => (
<div key={category}>
<Text className="text-2xl font-bold pb-4">{TaskCategoryLabels[category]}</Text>
<SimpleGrid columns={[1, 1, 2, 2, 3, 4]} gap={4}>
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
{TaskTypes.filter((task) => task.category === category).map((item) => (
<Link key={category + item.label} href={item.pathname}>
<GridItem
bg={backgroundColor}
borderRadius="xl"
+3 -5
View File
@@ -25,8 +25,8 @@ import clsx from "clsx";
import { useEffect, useReducer } from "react";
import { FiAlertCircle } from "react-icons/fi";
import { get, post } from "src/lib/api";
import { Message } from "src/types/Conversation";
import { colors } from "src/styles/Theme/colors";
import { Message } from "src/types/Conversation";
import useSWR from "swr";
import useSWRMutation from "swr/mutation";
@@ -114,9 +114,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
}, [data, isLoading]);
const { trigger } = useSWRMutation("/api/set_label", post, {
onSuccess: () => {
setIsEditing.off();
},
onSuccess: setIsEditing.off,
});
const submitResponse = () => {
@@ -149,7 +147,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
isLazy
lazyBehavior="keepMounted"
>
<Box display="flex" alignItems="center" gap="2">
<Box display="flex" alignItems="center" flexDirection={["column", "row"]} gap="2">
<PopoverAnchor>{props.children}</PopoverAnchor>
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
+12 -10
View File
@@ -1,9 +1,11 @@
import { Box, Divider, Flex, Text, useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { useTranslation } from "next-i18next";
import { useMemo } from "react";
export function Footer() {
const { t } = useTranslation();
const { colorMode } = useColorMode();
const backgroundColor = colorMode === "light" ? "white" : "gray.800";
const textColor = colorMode === "light" ? "black" : "gray.300";
@@ -33,10 +35,10 @@ export function Footer() {
<Box>
<Text fontSize="md" fontWeight="bold">
Open Assistant
{t("title")}
</Text>
<Text fontSize="sm" color="gray.500">
Conversational AI for everyone.
{t("conversational")}
</Text>
</Box>
</Flex>
@@ -45,23 +47,23 @@ export function Footer() {
<Box display="flex" flexDirection={["column", "row"]} gap={["6", "14"]} fontSize="sm">
<Flex direction="column" alignItems={["center", "start"]}>
<Text fontWeight="bold" color={textColor}>
Legal
{t("legal")}
</Text>
<FooterLink href="/privacy-policy" label="Privacy Policy" />
<FooterLink href="/terms-of-service" label="Terms of Service" />
<FooterLink href="/privacy-policy" label={t("privacy_policy")} />
<FooterLink href="/terms-of-service" label={t("terms_of_service")} />
</Flex>
<Flex direction="column" alignItems={["center", "start"]}>
<Text fontWeight="bold" color={textColor}>
Connect
{t("connect")}
</Text>
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label="Github" />
<FooterLink href="https://ykilcher.com/open-assistant-discord" label="Discord" />
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label={t("github")} />
<FooterLink href="https://ykilcher.com/open-assistant-discord" label={t("discord")} />
</Flex>
<Flex direction="column" alignItems={["center", "start"]}>
<Text fontWeight="bold" color={textColor}>
About
{t("about")}
</Text>
<FooterLink href="https://projects.laion.ai/Open-Assistant" label="Docs" />
<FooterLink href="https://projects.laion.ai/Open-Assistant" label={t("docs")} />
</Flex>
</Box>
</nav>
+5 -3
View File
@@ -1,7 +1,8 @@
import { Box, Button, Text, Flex } from "@chakra-ui/react";
import { Box, Button, Flex, Text } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { useTranslation } from "next-i18next";
import { Flags } from "react-feature-flags";
import { FaUser } from "react-icons/fa";
@@ -23,7 +24,8 @@ function AccountButton() {
);
}
export function Header(props) {
export function Header() {
const { t } = useTranslation();
const { data: session } = useSession();
const homeURL = session ? "/dashboard" : "/";
@@ -34,7 +36,7 @@ export function Header(props) {
<Flex alignItems="center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
<Text fontFamily="inter" fontSize="2xl" fontWeight="bold" ml="3">
Open Assistant
{t("title")}
</Text>
</Flex>
</Link>
+11 -9
View File
@@ -13,6 +13,7 @@ import {
} from "@chakra-ui/react";
import NextLink from "next/link";
import { signOut, useSession } from "next-auth/react";
import { useTranslation } from "next-i18next";
import React, { ElementType, useCallback } from "react";
import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi";
@@ -25,6 +26,7 @@ interface MenuOption {
}
export function UserMenu() {
const { t } = useTranslation();
const borderColor = useColorModeValue("gray.300", "gray.600");
const handleSignOut = useCallback(() => {
signOut({ callbackUrl: "/" });
@@ -36,23 +38,23 @@ export function UserMenu() {
}
const options: MenuOption[] = [
{
name: "Dashboard",
name: t("dashboard"),
href: "/dashboard",
desc: "Dashboard",
desc: t("dashboard"),
icon: FiLayout,
isExternal: false,
},
{
name: "Account Settings",
name: t("account_settings"),
href: "/account",
desc: "Account Settings",
desc: t("account_settings"),
icon: FiSettings,
isExternal: false,
},
{
name: "Report a Bug",
name: t("report_a_bug"),
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
desc: "Report a Bug",
desc: t("report_a_bug"),
icon: FiAlertTriangle,
isExternal: true,
},
@@ -60,9 +62,9 @@ export function UserMenu() {
if (session.user.role === "admin") {
options.unshift({
name: "Admin Dashboard",
name: t("admin_dashboard"),
href: "/admin",
desc: "Admin Dashboard",
desc: t("admin_dashboard"),
icon: FiShield,
isExternal: false,
});
@@ -105,7 +107,7 @@ export function UserMenu() {
<MenuDivider />
<MenuItem gap="3" borderRadius="md" p="4" onClick={handleSignOut}>
<FiLogOut className="text-blue-500" aria-hidden="true" />
<Text>Sign Out</Text>
<Text>{t("sign_out")}</Text>
</MenuItem>
</MenuList>
</Menu>
+1 -1
View File
@@ -2,8 +2,8 @@ import { Box, Text, useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import { useTranslation } from "next-i18next";
import { Container } from "./Container";
import { AnimatedCircles } from "./AnimatedCircles";
import { Container } from "./Container";
export function Hero() {
const { t } = useTranslation("index");
+3 -3
View File
@@ -23,7 +23,7 @@ export const getDefaultLayout = (page: React.ReactElement) => (
export const getTransparentHeaderLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
<Header />
{page}
<Footer />
</div>
@@ -31,7 +31,7 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
export const getDashboardLayout = (page: React.ReactElement) => (
<Grid templateRows="min-content 1fr" h="full">
<Header transparent={true} />
<Header />
<SideMenuLayout
menuButtonOptions={[
{
@@ -66,7 +66,7 @@ export const getDashboardLayout = (page: React.ReactElement) => (
export const getAdminLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
<Header />
<SideMenuLayout
menuButtonOptions={[
{
@@ -9,7 +9,7 @@ interface MessageTableProps {
export function MessageTable({ messages, enableLink }: MessageTableProps) {
return (
<Stack spacing="3">
<Stack spacing="4">
{messages.map((item) => (
<MessageTableEntry enabled={enableLink} item={item} key={item.id + item.frontend_message_id} />
))}
@@ -1,6 +1,8 @@
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { Avatar, Box, HStack, LinkBox, useBreakpoint, useBreakpointValue, useColorModeValue } from "@chakra-ui/react";
import { boolean } from "boolean";
import Link from "next/link";
import { useRouter } from "next/router";
import { useCallback, useMemo } from "react";
import { FlaggableElement } from "src/components/FlaggableElement";
import { Message } from "src/types/Conversation";
@@ -10,47 +12,48 @@ interface MessageTableEntryProps {
}
export function MessageTableEntry(props: MessageTableEntryProps) {
const router = useRouter();
const { item } = props;
const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]);
const backgroundColor = useColorModeValue("gray.100", "gray.700");
const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B");
const avatarColor = useColorModeValue("white", "black");
const borderColor = useColorModeValue("blackAlpha.200", "whiteAlpha.200");
const inlineAvatar = useBreakpointValue({ base: true, sm: false });
const avatar = useMemo(
() => (
<Avatar
borderColor={borderColor}
size={inlineAvatar ? "xs" : "sm"}
mr={inlineAvatar ? 2 : 0}
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
/>
),
[borderColor, inlineAvatar, item.is_assistant]
);
return (
<FlaggableElement message={item}>
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
<Box borderRadius="full" border="solid" borderWidth="1px" borderColor={borderColor} bg={avatarColor}>
<Avatar
size="sm"
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
/>
{!inlineAvatar && avatar}
<Box
width={["full", "full", "full", "fit-content"]}
maxWidth={["full", "full", "full", "2xl"]}
p="4"
borderRadius="md"
bg={item.is_assistant ? backgroundColor : backgroundColor2}
onClick={props.enabled && goToMessage}
_hover={props.enabled && { cursor: "pointer", opacity: 0.9 }}
>
{inlineAvatar && avatar}
{item.text}
</Box>
{props.enabled ? (
<Box width={["full", "full", "full", "fit-content"]} maxWidth={["full", "full", "full", "2xl"]}>
<Link href={`/messages/${item.id}`}>
<LinkBox
bg={item.is_assistant ? backgroundColor : backgroundColor2}
p="4"
borderRadius="md"
whiteSpace="pre-line"
>
{item.text}
</LinkBox>
</Link>
</Box>
) : (
<Box
width={["full", "full", "full", "fit-content"]}
maxWidth={["full", "full", "full", "2xl"]}
bg={item.is_assistant ? backgroundColor : backgroundColor2}
p="4"
borderRadius="md"
>
{item.text}
</Box>
)}
</HStack>
</FlaggableElement>
);
+5 -8
View File
@@ -1,22 +1,19 @@
import { Box, BoxProps, useColorModeValue } from "@chakra-ui/react";
import clsx from "clsx";
import { PropsWithChildren } from "react";
interface SurveyCardProps {
className?: string;
children: React.ReactNode;
}
export const SurveyCard = (props: SurveyCardProps) => {
export const SurveyCard = (props: PropsWithChildren<{ className?: string }>) => {
const backgroundColor = useColorModeValue("white", "gray.700");
const BoxClasses: BoxProps = {
gap: "2",
borderRadius: "xl",
shadow: "base",
className: "p-4 sm:p-6",
className: clsx("p-4 sm:p-6", props.className),
};
return (
<Box bg={backgroundColor} {...BoxClasses}>
<Box as="section" bg={backgroundColor} {...BoxClasses}>
{props.children}
</Box>
);
+10 -3
View File
@@ -1,5 +1,5 @@
export enum TaskCategory {
Tasks = "Tasks",
Random = "Random",
Create = "Create",
Evaluate = "Evaluate",
Label = "Label",
@@ -20,12 +20,19 @@ export interface TaskInfo {
unchanged_message?: string;
}
/**
 * Human-readable heading text for every `TaskCategory`.
 *
 * The mapped type forces an entry for each enum member, so adding a category
 * without a label is a compile error.
 */
export const TaskCategoryLabels: { [key in TaskCategory]: string } = {
[TaskCategory.Random]: "I'm feeling lucky",
[TaskCategory.Create]: "Create",
[TaskCategory.Evaluate]: "Evaluate",
[TaskCategory.Label]: "Label",
};
export const TaskTypes: TaskInfo[] = [
// general/random
{
label: "Start a Task",
desc: "Help us improve Open Assistant by starting a random task.",
category: TaskCategory.Tasks,
category: TaskCategory.Random,
pathname: "/tasks/random",
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
type: "random",
@@ -121,7 +128,7 @@ export const TaskTypes: TaskInfo[] = [
category: TaskCategory.Label,
pathname: "/label/label_prompter_reply",
help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/label_prompter_reply",
overview: "Given the following discussion, provide labels for the final prompt",
overview: "Given the following discussion, provide labels for the final prompt.",
type: "label_prompter_reply",
mode: "full",
update_type: "text_labels",
@@ -1,23 +1,22 @@
import { useState } from "react";
import { get, post } from "src/lib/api";
import { BaseTask, TaskResponse } from "src/types/Task";
import { BaseTask, TaskResponse, TaskType as TaskTypeEnum } from "src/types/Task";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskApiEndpoint: string) => {
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskType: TaskTypeEnum) => {
type ConcreteTaskResponse = TaskResponse<TaskType>;
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>("/api/new_task/" + taskApiEndpoint, get, {
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>("/api/new_task/" + taskType, get, {
onSuccess: (data) => setTasks([data]),
revalidateOnMount: true,
dedupingInterval: 500,
});
const { trigger } = useSWRMutation("/api/update_task", post, {
onSuccess: async (response) => {
const newTask: ConcreteTaskResponse = response;
onSuccess: async (newTask: ConcreteTaskResponse) => {
setTasks((oldTasks) => [...oldTasks, newTask]);
mutate();
},
+7
View File
@@ -0,0 +1,7 @@
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
/**
 * Shared `getStaticProps` for pages whose only server-side need is translations.
 *
 * Calls `serverSideTranslations` with just the requested locale and exposes the
 * result as the page props.
 */
export const getDefaultStaticProps = async ({ locale }) => {
  const translations = await serverSideTranslations(locale);
  return { props: { ...translations } };
};
+13 -14
View File
@@ -1,7 +1,7 @@
import { JWT } from "next-auth/jwt";
import type { Message } from "src/types/Conversation";
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
import type { BackendUser } from "src/types/Users";
import type { AvailableTasks } from "src/types/Task";
import type { BackendUser, BackendUserCore } from "src/types/Users";
export class OasstError {
message: string;
@@ -108,14 +108,10 @@ export class OasstApiClient {
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
// This is a raw Json type, so we can't use it to strongly type the task.
async fetchTask(taskType: string, userToken: JWT): Promise<any> {
async fetchTask(taskType: string, user: BackendUserCore): Promise<any> {
return this.post("/api/v1/tasks/", {
type: taskType,
user: {
id: userToken.sub,
display_name: userToken.name,
auth_method: "local",
},
user,
});
}
@@ -140,15 +136,11 @@ export class OasstApiClient {
messageId: string,
userMessageId: string,
content: object,
userToken: JWT
user: BackendUserCore
): Promise<any> {
return this.post("/api/v1/tasks/interaction", {
type: updateType,
user: {
id: userToken.sub,
display_name: userToken.name,
auth_method: "local",
},
user,
task_id: taskId,
message_id: messageId,
user_message_id: userMessageId,
@@ -224,6 +216,13 @@ export class OasstApiClient {
async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise<LeaderboardReply> {
return this.get(`/api/v1/leaderboards/${time_frame}`);
}
/**
* Returns the counts of all tasks (some might be zero).
*
* Sends `user` as the POST body to `/api/v1/tasks/availability`.
*
* @param user The backend user core (id, display name, auth method) the counts are requested for.
* @returns The backend reply, typed as a map from task type to a number.
*/
async fetch_available_tasks(user: BackendUserCore): Promise<AvailableTasks> {
return this.post(`/api/v1/tasks/availability`, user);
}
}
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
+38
View File
@@ -0,0 +1,38 @@
import prisma from "src/lib/prismadb";
import type { BackendUserCore } from "src/types/Users";
/**
 * Returns a `BackendUserCore` that can be used for interacting with the Backend service.
 *
 * The user is looked up by web auth id together with their linked accounts.
 * Without linked accounts the local id is reported with the "local" auth
 * method; otherwise the first account in the returned list supplies the
 * provider name and the provider account id.
 *
 * @param {string} id The user's web auth id.
 *
 * @return {Promise<BackendUserCore>} The most specific auth type and id for the user.
 *
 * @throws {Error} If no user with the given id exists.
 */
const getBackendUserCore = async (id: string): Promise<BackendUserCore> => {
  const user = await prisma.user.findUnique({
    where: { id },
    select: {
      id: true,
      name: true,
      accounts: true,
    },
  });

  // `findUnique` resolves to null when nothing matches. Fail with a clear
  // message instead of a TypeError on `user.accounts` below.
  if (!user) {
    throw new Error(`No user found for id "${id}"`);
  }

  // If there are no linked accounts, just use what we have locally.
  if (user.accounts.length === 0) {
    return {
      id: user.id,
      display_name: user.name,
      auth_method: "local",
    } as BackendUserCore;
  }

  // Otherwise, use the first linked account in the list.
  // NOTE(review): the query requests no ordering, so "first" is whatever order
  // the database returns - confirm this is stable for users with several accounts.
  const [account] = user.accounts;
  return {
    id: account.providerAccountId,
    display_name: user.name,
    auth_method: account.provider,
  } as BackendUserCore;
};
export { getBackendUserCore };
+1
View File
@@ -3,6 +3,7 @@ import Head from "next/head";
import { FiAlertTriangle } from "react-icons/fi";
import { EmptyState } from "src/components/EmptyState";
import { getTransparentHeaderLayout } from "src/components/Layout";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
function Error() {
return (
+1
View File
@@ -3,6 +3,7 @@ import Head from "next/head";
import { FiAlertTriangle } from "react-icons/fi";
import { EmptyState } from "src/components/EmptyState";
import { getTransparentHeaderLayout } from "src/components/Layout";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
function ServerError() {
return (
+1
View File
@@ -4,6 +4,7 @@ import { Container } from "src/components/Container";
import Roadmap from "src/components/Roadmap";
import Services from "src/components/Services";
import Vision from "src/components/Vision";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const AboutPage = () => {
return (
+1
View File
@@ -4,6 +4,7 @@ import Router from "next/router";
import { useSession } from "next-auth/react";
import React from "react";
import { Control, useForm, useWatch } from "react-hook-form";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
export default function Account() {
const { data: session } = useSession();
+26 -10
View File
@@ -1,8 +1,11 @@
import { Button } from "@chakra-ui/react";
import { Button, Divider, Flex, Grid, Icon, Text } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { useSession } from "next-auth/react";
import React from "react";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
import { MdOutlineEdit } from "react-icons/md";
import { SurveyCard } from "src/components/Survey/SurveyCard";
export default function Account() {
const { data: session } = useSession();
@@ -19,15 +22,28 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<div className="oa-basic-theme">
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<Button>
<Link href="/account/edit">Edit Username</Link>
</Button>
<p>{session.user.email}</p>
</main>
</div>
<main className="oa-basic-theme p-6">
<Flex m="auto" className="max-w-7xl" alignContent="center">
<SurveyCard className="w-full">
<Text as="b" display="block" fontSize="2xl" py={2}>
Your Account
</Text>
<Divider />
<Grid gridTemplateColumns="repeat(2, max-content)" alignItems="center" gap={6} py={4}>
<Text as="b">Username</Text>
<Flex gap={2}>
{session.user.name ?? "(No username)"}
<Link href="/account/edit">
<Icon boxSize={5} as={MdOutlineEdit} />
</Link>
</Flex>
<Text as="b">Email</Text>
<Text>{session.user.email ?? "(No Email)"}</Text>
</Grid>
<p></p>
</SurveyCard>
</Flex>
</main>
</>
);
}
+1
View File
@@ -4,6 +4,7 @@ import { useSession } from "next-auth/react";
import { useEffect } from "react";
import { getAdminLayout } from "src/components/Layout";
import { UserTable } from "src/components/UserTable";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
/**
* Provides the admin index page that will display a list of users and give
+3 -1
View File
@@ -3,6 +3,7 @@ import { InferGetServerSidePropsType } from "next";
import Head from "next/head";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
import { useEffect } from "react";
import { useForm } from "react-hook-form";
import { getAdminLayout } from "src/components/Layout";
@@ -111,7 +112,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType<typeof getServerSidePr
/**
* Fetch the user's data on the server side when rendering.
*/
export async function getServerSideProps({ query }) {
export async function getServerSideProps({ query, locale }) {
const backend_user = await oasstApiClient.fetch_user(query.id);
const local_user = await prisma.user.findUnique({
where: { id: backend_user.id },
@@ -126,6 +127,7 @@ export async function getServerSideProps({ query }) {
return {
props: {
user,
...(await serverSideTranslations(locale, ["common"])),
},
};
}
+11
View File
@@ -0,0 +1,11 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import { getBackendUserCore } from "src/lib/users";
/**
 * Returns how many tasks of each type the backend currently reports for the
 * requesting user.
 *
 * The handler is wrapped in `withoutRole("banned", ...)` (presumably rejecting
 * banned users before this code runs - confirm in src/lib/auth). Backend
 * failures are logged and answered with a 500 carrying the error, matching the
 * other task API handlers.
 */
const handler = withoutRole("banned", async (req, res, token) => {
  const user = await getBackendUserCore(token.sub);

  let availableTasks;
  try {
    availableTasks = await oasstApiClient.fetch_available_tasks(user);
  } catch (err) {
    console.error(err);
    return res.status(500).json(err);
  }

  res.status(200).json(availableTasks);
});
export default handler;
@@ -1,6 +1,7 @@
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
import { getBackendUserCore } from "src/lib/users";
/**
* Returns a new task created from the Task Backend. We do a few things here:
@@ -14,9 +15,10 @@ const handler = withoutRole("banned", async (req, res, token) => {
// Fetch the new task.
const { task_type } = req.query;
const user = await getBackendUserCore(token.sub);
let task;
try {
task = await oasstApiClient.fetchTask(task_type as string, token);
task = await oasstApiClient.fetchTask(task_type as string, user);
} catch (err) {
console.error(err);
res.status(500).json(err);
+3 -1
View File
@@ -2,6 +2,7 @@ import { Prisma } from "@prisma/client";
import { withoutRole } from "src/lib/auth";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
import { getBackendUserCore } from "src/lib/users";
/**
* Stores the task interaction with the Task Backend and then returns the next task generated.
@@ -39,9 +40,10 @@ const handler = withoutRole("banned", async (req, res, token) => {
},
});
const user = await getBackendUserCore(token.sub);
let newTask;
try {
newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, token);
newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, user);
} catch (err) {
console.error(JSON.stringify(err));
return res.status(500).json(err);
+4 -3
View File
@@ -5,6 +5,7 @@ import Head from "next/head";
import Link from "next/link";
import { useRouter } from "next/router";
import { ClientSafeProvider, getProviders, signIn } from "next-auth/react";
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
import React, { useEffect, useRef, useState } from "react";
import { useForm } from "react-hook-form";
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
@@ -47,7 +48,6 @@ interface SigninProps {
function Signin({ providers }: SigninProps) {
const router = useRouter();
const { discord, email, github, credentials } = providers;
const emailEl = useRef(null);
const [error, setError] = useState("");
useEffect(() => {
@@ -151,7 +151,7 @@ function Signin({ providers }: SigninProps) {
Signin.getLayout = (page) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
<Header />
{page}
<Footer />
</div>
@@ -209,11 +209,12 @@ const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSaf
);
};
export const getServerSideProps: GetServerSideProps<SigninProps> = async () => {
export const getServerSideProps: GetServerSideProps<SigninProps> = async ({ locale }) => {
const providers = await getProviders();
return {
props: {
providers,
...(await serverSideTranslations(locale, ["common"])),
},
};
};
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const AssistantReply = () => {
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const InitialPrompt = () => {
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
+1
View File
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const UserReply = () => {
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
+18 -2
View File
@@ -1,10 +1,20 @@
import { Flex } from "@chakra-ui/react";
import Head from "next/head";
import { useMemo } from "react";
import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard";
import { getDashboardLayout } from "src/components/Layout";
import { TaskCategory } from "src/components/Tasks/TaskTypes";
import { get } from "src/lib/api";
import type { AvailableTasks, TaskType } from "src/types/Task";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
import useSWRImmutable from "swr/immutable";
const Dashboard = () => {
const { data } = useSWRImmutable<AvailableTasks>("/api/available_tasks", get);
// TODO: show only these tasks:
const availableTasks = useMemo(() => filterAvailableTasks(data ?? {}), [data]);
return (
<>
<Head>
@@ -13,13 +23,19 @@ const Dashboard = () => {
</Head>
<Flex direction="column" gap="10">
<WelcomeCard />
<TaskOption displayTaskCategories={[TaskCategory.Tasks]} />
<TaskOption displayTaskCategories={[TaskCategory.Random]} />
<LeaderboardTable />
</Flex>
</>
);
};
Dashboard.getLayout = (page) => getDashboardLayout(page);
Dashboard.getLayout = getDashboardLayout;
export default Dashboard;
/**
 * Lists the task types that have at least one task available, ordered from the
 * highest count to the lowest.
 *
 * @param availableTasks Map from task type to its available count; may be partial or empty.
 * @returns The task types whose count is greater than zero.
 */
const filterAvailableTasks = (availableTasks: Partial<AvailableTasks>) => {
  const nonEmpty = Object.entries(availableTasks).filter((entry) => entry[1] > 0);
  // `filter` produced a fresh array, so sorting it in place leaves the input untouched.
  nonEmpty.sort((left, right) => right[1] - left[1]);
  return nonEmpty.map((entry) => entry[0]) as TaskType[];
};
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const RankAssistantReplies = () => {
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const RankInitialPrompts = () => {
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const RankUserReplies = () => {
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const LabelAssistantReply = () => {
const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const LabelInitialPrompt = () => {
const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const LabelPrompterReply = () => {
const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
+1
View File
@@ -2,6 +2,7 @@ import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-u
import Head from "next/head";
import { getDashboardLayout } from "src/components/Layout";
import { LeaderboardGridCell } from "src/components/LeaderboardGridCell";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
const Leaderboard = () => {
+9 -5
View File
@@ -1,5 +1,6 @@
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
import { getDashboardLayout } from "src/components/Layout";
import { MessageLoading } from "src/components/Loading/MessageLoading";
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
@@ -48,10 +49,13 @@ const MessageDetail = ({ id }: { id: string }) => {
);
};
MessageDetail.getInitialProps = async ({ query }) => {
const { id } = query;
return { id };
};
MessageDetail.getLayout = (page) => getDashboardLayout(page);
/**
 * Server-side props for the message detail page: the message id taken from the
 * route query plus the "common" translations for the requested locale.
 */
export const getServerSideProps = async ({ locale, query }) => {
  const { id } = query;
  const translations = await serverSideTranslations(locale, ["common"]);
  return { props: { id, ...translations } };
};
export default MessageDetail;
+1
View File
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
import { MessageTable } from "src/components/Messages/MessageTable";
import { get } from "src/lib/api";
import useSWRImmutable from "swr/immutable";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const MessagesDashboard = () => {
const boxBgColor = useColorModeValue("white", "gray.800");
+1
View File
@@ -3,6 +3,7 @@ import Head from "next/head";
import { getTransparentHeaderLayout } from "src/components/Layout";
import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard";
import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const PrivacyPolicy = () => {
const backgroundColor = useColorModeValue("gray.100", "gray.800");
+2 -1
View File
@@ -4,9 +4,10 @@ import { getDashboardLayout } from "src/components/Layout";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI";
import { TaskType } from "src/types/Task";
const RandomTask = () => {
const { tasks, isLoading, trigger, reset } = useGenericTaskAPI("random");
const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random);
if (isLoading) {
return <LoadingScreen text="Loading..." />;
+1
View File
@@ -3,6 +3,7 @@ import Head from "next/head";
import { getTransparentHeaderLayout } from "src/components/Layout";
import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard";
import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard";
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
const TermsOfService = () => {
const TermsData = [
+4
View File
@@ -10,6 +10,8 @@ export const enum TaskType {
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
random = "random",
}
// we need to reconsider how to handle task content types
@@ -32,3 +34,5 @@ export interface TaskResponse<Task extends BaseTask> {
userId: string;
task: Task;
}
export type AvailableTasks = { [taskType in TaskType]: number };
+6 -4
View File
@@ -1,7 +1,4 @@
/**
* Reports the Backend's knowledge of a user.
*/
export interface BackendUser {
export interface BackendUserCore {
/**
* The user's unique ID according to the `auth_method`.
*/
@@ -18,7 +15,12 @@ export interface BackendUser {
* - local
*/
auth_method: string;
}
/**
* Reports the Backend's knowledge of a user.
*/
export interface BackendUser extends BackendUserCore {
/**
* The backend's UUID for this user.
*/