mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
Merge branch 'main' into 766_admin_enhancement
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""add ix_user_display_name_id
|
||||
|
||||
Revision ID: 4f26fec4d204
|
||||
Revises: 0964ac95170d
|
||||
Create Date: 2023-01-19 22:00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4f26fec4d204"
|
||||
down_revision = "7f0a28a156f4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index("ix_user_display_name_id", "user", ["display_name", "id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_user_display_name_id", table_name="user")
|
||||
# ### end Alembic commands ###
|
||||
@@ -15,34 +15,29 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=list[protocol.FrontEndUser])
|
||||
def get_users(
|
||||
@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True)
|
||||
def get_users_ordered_by_username(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
gt: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
gte_username: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_username: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
search_text: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/by_display_name")
|
||||
def query_frontend_users_by_display_name(
|
||||
search_text: str,
|
||||
exact: bool = False,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(20, gt=0, le=1000),
|
||||
auth_method: str = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users_by_display_name(
|
||||
search_text=search_text, exact=exact, api_client_id=api_client_id, limit=max_count, auth_method=auth_method
|
||||
users = ur.query_users_ordered_by_username(
|
||||
api_client_id=api_client_id,
|
||||
gte_username=gte_username,
|
||||
gt_id=gt_id,
|
||||
lte_username=lte_username,
|
||||
lt_id=lt_id,
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@@ -16,7 +16,61 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=protocol.FrontEndUser)
|
||||
@router.get("/by_username", response_model=list[protocol.FrontEndUser])
|
||||
def get_users_ordered_by_username(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
gte_username: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_username: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
search_text: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users_ordered_by_username(
|
||||
api_client_id=api_client_id,
|
||||
gte_username=gte_username,
|
||||
gt_id=gt_id,
|
||||
lte_username=lte_username,
|
||||
lt_id=lt_id,
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/by_display_name", response_model=list[protocol.FrontEndUser])
|
||||
def get_users_ordered_by_display_name(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
gte_display_name: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_display_name: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users_ordered_by_display_name(
|
||||
api_client_id=api_client_id,
|
||||
gte_display_name=gte_display_name,
|
||||
gt_id=gt_id,
|
||||
lte_display_name=lte_display_name,
|
||||
lt_id=lt_id,
|
||||
auth_method=auth_method,
|
||||
search_text=search_text,
|
||||
limit=max_count,
|
||||
)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=protocol.FrontEndUser)
|
||||
def get_user(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
@@ -31,7 +85,7 @@ def get_user(
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
@router.put("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def update_user(
|
||||
user_id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
@@ -46,7 +100,7 @@ def update_user(
|
||||
ur.update_user(user_id, enabled, notes)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -10,7 +10,10 @@ from sqlmodel import AutoString, Field, Index, SQLModel
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "user"
|
||||
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
__table_args__ = (
|
||||
Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),
|
||||
Index("ix_user_display_name_id", "display_name", "id", unique=True),
|
||||
)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
|
||||
@@ -5,7 +5,7 @@ from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import Session, and_, or_
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -135,13 +135,16 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
return user
|
||||
|
||||
def query_users(
|
||||
def query_users_ordered_by_username(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = 20,
|
||||
gt: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
gte_username: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_username: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
@@ -150,34 +153,52 @@ class UserRepository:
|
||||
if api_client_id != self.api_client.id:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
users = self.db.query(User)
|
||||
qry = self.db.query(User).order_by(User.username, User.id)
|
||||
|
||||
if api_client_id:
|
||||
users = users.filter(User.api_client_id == api_client_id)
|
||||
if gte_username is not None:
|
||||
if gt_id:
|
||||
qry = qry.filter(
|
||||
or_(User.username > gte_username, and_(User.username == gte_username, User.id > gt_id))
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.username >= gte_username)
|
||||
elif gt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if lte_username is not None:
|
||||
if lt_id:
|
||||
qry = qry.filter(
|
||||
or_(User.username < lte_username, and_(User.username == lte_username, User.id < lt_id))
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.username <= lte_username)
|
||||
elif lt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if auth_method:
|
||||
users = users.filter(User.auth_method == auth_method)
|
||||
qry = qry.filter(User.auth_method == auth_method)
|
||||
if api_client_id:
|
||||
qry = qry.filter(User.api_client_id == api_client_id)
|
||||
|
||||
users = users.order_by(User.id)
|
||||
|
||||
if gt:
|
||||
users = users.filter(User.id > gt)
|
||||
|
||||
if lt:
|
||||
users = users.filter(User.id < lt).order_by(None).order_by(User.id.desc())
|
||||
if search_text:
|
||||
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
|
||||
qry = qry.filter(User.username.like(pattern))
|
||||
|
||||
if limit is not None:
|
||||
users = users.limit(limit)
|
||||
qry = qry.limit(limit)
|
||||
|
||||
return users.all()
|
||||
return qry.all()
|
||||
|
||||
def query_users_by_display_name(
|
||||
def query_users_ordered_by_display_name(
|
||||
self,
|
||||
search_text: str,
|
||||
exact: Optional[bool] = False,
|
||||
limit: Optional[int] = 20,
|
||||
gte_display_name: Optional[str] = None,
|
||||
gt_id: Optional[UUID] = None,
|
||||
lte_display_name: Optional[str] = None,
|
||||
lt_id: Optional[UUID] = None,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
search_text: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
) -> list[User]:
|
||||
if not self.api_client.trusted:
|
||||
if not api_client_id:
|
||||
@@ -186,11 +207,40 @@ class UserRepository:
|
||||
if api_client_id != self.api_client.id:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
qry = self.db.query(User).order_by(User.display_name)
|
||||
qry = self.db.query(User).order_by(User.display_name, User.id)
|
||||
|
||||
if exact:
|
||||
qry = qry.filter(User.display_name == search_text)
|
||||
else:
|
||||
if gte_display_name is not None:
|
||||
if gt_id:
|
||||
qry = qry.filter(
|
||||
or_(
|
||||
User.display_name > gte_display_name,
|
||||
and_(User.display_name == gte_display_name, User.id > gt_id),
|
||||
)
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.display_name >= gte_display_name)
|
||||
elif gt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if lte_display_name is not None:
|
||||
if lt_id:
|
||||
qry = qry.filter(
|
||||
or_(
|
||||
User.display_name < lte_display_name,
|
||||
and_(User.display_name == lte_display_name, User.id < lt_id),
|
||||
)
|
||||
)
|
||||
else:
|
||||
qry = qry.filter(User.display_name <= lte_display_name)
|
||||
elif lt_id:
|
||||
raise OasstError("Need id and name for keyset pagination", OasstErrorCode.GENERIC_ERROR)
|
||||
|
||||
if auth_method:
|
||||
qry = qry.filter(User.auth_method == auth_method)
|
||||
if api_client_id:
|
||||
qry = qry.filter(User.api_client_id == api_client_id)
|
||||
|
||||
if search_text:
|
||||
pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
|
||||
qry = qry.filter(User.display_name.like(pattern))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base
|
||||
learning_rate: 1e-5
|
||||
scheduler: cosine
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
|
||||
@@ -60,6 +60,26 @@ python trainer.py --configs defaults your-model-name --deepspeed
|
||||
|
||||
## Dataset choices
|
||||
|
||||
To specify which translation pair for
|
||||
[WMT](https://huggingface.co/datasets/wmt19) and
|
||||
[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply
|
||||
add the supported language pair at the postfix
|
||||
|
||||
```
|
||||
datasets:
|
||||
- wmt2019_zh-en
|
||||
- wmt2019_ru-en
|
||||
- wmt2019_de-en
|
||||
- ted_trans_nl-en
|
||||
- ted_trans_de-ja
|
||||
```
|
||||
|
||||
Currently only these languages are supported via prompt translation:
|
||||
|
||||
```
|
||||
ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
Experimental results in wandb
|
||||
|
||||
@@ -29,6 +29,14 @@ defaults:
|
||||
- soda
|
||||
- joke
|
||||
- gsm8k
|
||||
- dive_mt
|
||||
- wmt2019_zh-en
|
||||
- wmt2019_ru-en
|
||||
- wmt2019_de-en
|
||||
- ted_trans_nl-en
|
||||
- ted_trans_de-ja
|
||||
- instruct_tuning
|
||||
- wmt2019_de-en
|
||||
- samsum
|
||||
- soda_dialogue
|
||||
cache_dir: .cache
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
# Dataset collections overview:
|
||||
|
||||
currently dataset can be divided into 3 classes
|
||||
|
||||
- language knowledge
|
||||
|
||||
- summarization
|
||||
|
||||
- translation
|
||||
|
||||
- dialogue : don't let user know you are a robot
|
||||
|
||||
- STEM : knowledge about the world
|
||||
|
||||
- coding
|
||||
|
||||
- world knowledge <= ideally we want to handle this via prefix context
|
||||
|
||||
Issues and TODO:
|
||||
|
||||
- as dataset are growing, how can we update this section less
|
||||
|
||||
- ideally we can update the config yaml and new dataset will be download from
|
||||
hub
|
||||
|
||||
- one possible idea is we upload the trasform format of these dataset to the
|
||||
OA hub
|
||||
@@ -1,11 +1,26 @@
|
||||
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
|
||||
"""
|
||||
High level functions for model training
|
||||
"""
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
|
||||
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
|
||||
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"]
|
||||
SUMMARIZATION_DATASETS = [
|
||||
"xsum",
|
||||
"cnn_dailymail",
|
||||
"samsum",
|
||||
"multi_news",
|
||||
"scitldr",
|
||||
"billsum",
|
||||
"debate_sum",
|
||||
"tldr_news",
|
||||
]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
@@ -25,14 +40,34 @@ def get_one_dataset(conf, dataset_name):
|
||||
|
||||
elif dataset_name in SUMMARIZATION_DATASETS:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
if dataset_name == "debate_sum":
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
else:
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
elif "ted_trans" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
dataset = TEDTalk(pair=language_pair, split="train")
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif "wmt2019" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
train = WMT2019(pair=language_pair, split="train")
|
||||
eval = WMT2019(pair=language_pair, split="validation")
|
||||
elif dataset_name == "dive_mt":
|
||||
dataset = DiveMT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "prompt_dialogue":
|
||||
dataset = PromptGeneratedDataset(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "prosocial_dialogue":
|
||||
train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train")
|
||||
eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation")
|
||||
elif dataset_name == "explain_prosocial":
|
||||
train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train")
|
||||
eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation")
|
||||
elif dataset_name == "soda":
|
||||
dataset = SODA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
@@ -42,6 +77,9 @@ def get_one_dataset(conf, dataset_name):
|
||||
elif dataset_name == "joke":
|
||||
dataset = JokeExplaination(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "instruct_tuning":
|
||||
dataset = InstructionTuning(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class DialogueDataCollator:
|
||||
|
||||
for feature_one in features:
|
||||
assert len(feature_one) % 2 == 0, "Number of messages must be even"
|
||||
# TODO: we should push this to dataset __getitem__
|
||||
messages = [
|
||||
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
|
||||
+ x
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
@@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset):
|
||||
we are ignoring results with multiple lines for now
|
||||
"""
|
||||
|
||||
name = "prompt_dialogue"
|
||||
url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
@@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
class InstructionTuning(Dataset):
|
||||
"""
|
||||
We have seen some promising capabilities from instruction tuning
|
||||
with the following mix of datasets that are derived from datasets
|
||||
available online.
|
||||
The files for this data are in json format as a list of tuples
|
||||
where each tuple is (source,instruction_response_pair)
|
||||
|
||||
- instruction_tuning_dataset_alpha_part1.json
|
||||
- instruction_tuning_dataset_alpha_part2.json
|
||||
|
||||
Not to be confused with unatural instruction
|
||||
"""
|
||||
|
||||
name = "instruction_dataset"
|
||||
url_part_2 = (
|
||||
"https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json"
|
||||
)
|
||||
url_part_1 = (
|
||||
"https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json"
|
||||
)
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
super().__init__()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
self.pairs = []
|
||||
for file_link in [self.url_part_1, self.url_part_2]:
|
||||
basename = file_link.split("/")[-1]
|
||||
instruction_tune_file = os.path.join(cache_dir, basename)
|
||||
if not os.path.exists(instruction_tune_file):
|
||||
with urlopen(file_link) as file:
|
||||
content = file.read().decode()
|
||||
with open(instruction_tune_file, "w", encoding="utf-8") as fout:
|
||||
fout.write(content)
|
||||
|
||||
with open(instruction_tune_file, "r", encoding="utf-8") as f:
|
||||
datasets = json.load(f)
|
||||
for row in datasets:
|
||||
_, response_pair = row
|
||||
question, answer = response_pair.split("\n\n", maxsplit=1)
|
||||
answer = answer.replace("<|endoftext|>", "").strip()
|
||||
self.pairs.append((question, answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
"""
|
||||
Open / close book QA datasets
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# @agoryuno contributed this
|
||||
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
@@ -75,6 +82,9 @@ class QADataset(Dataset):
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
|
||||
name = "webgpt"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -89,7 +99,9 @@ class WebGPT(Dataset):
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
# only keep the best answer
|
||||
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
questions[question] = re_reference_remove.sub(
|
||||
"", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
)
|
||||
|
||||
self.questions = questions
|
||||
|
||||
@@ -103,6 +115,9 @@ class WebGPT(Dataset):
|
||||
|
||||
|
||||
class SODA(Dataset):
|
||||
|
||||
name = "soda"
|
||||
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
@@ -207,8 +222,8 @@ class SODADialogue(Dataset):
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
""" """
|
||||
|
||||
name = "joke"
|
||||
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
@@ -240,3 +255,6 @@ class JokeExplaination(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
# https://huggingface.co/datasets/aquamuse
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
"""
|
||||
Summarize different spectrum of documents
|
||||
"""
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
@@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = {
|
||||
}
|
||||
|
||||
summarization_config_mapping = {
|
||||
"cnn_dailymail": ("3.0.0",),
|
||||
"samsum": (),
|
||||
"xsum": (),
|
||||
"multi_news": (),
|
||||
"scitldr": ("AIC",),
|
||||
"billsum": (),
|
||||
"reddit": (),
|
||||
"cnn_dailymail": (
|
||||
"cnn_dailymail",
|
||||
"3.0.0",
|
||||
),
|
||||
"samsum": ("samsum",),
|
||||
"xsum": ("xsum",),
|
||||
"multi_news": ("multi_news",),
|
||||
"scitldr": (
|
||||
"scitldr",
|
||||
"AIC",
|
||||
),
|
||||
"billsum": ("billsum",),
|
||||
"reddit": ("reddit",),
|
||||
"tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news
|
||||
"debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum
|
||||
}
|
||||
|
||||
summarization_name_mapping = {
|
||||
@@ -29,6 +40,8 @@ summarization_name_mapping = {
|
||||
"scitldr": ("source", "target"),
|
||||
"billsum": ("text", "summary"),
|
||||
"reddit": ("content", "summary"),
|
||||
"tldr_news": ("content", "headline"),
|
||||
"debate_sum": ("Full-Document", "Extract"),
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +56,7 @@ def index_summary_merge(text, summary):
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
self.name = dataset
|
||||
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
SFT dataset to reject toxic questions
|
||||
|
||||
"""
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProsocialDialogueExplaination(Dataset):
|
||||
name = "prosocial_explain"
|
||||
TEMPLATE = [
|
||||
# 0 : reply or sentence of interest, 1 : reason of caution
|
||||
("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
|
||||
("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"),
|
||||
("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"),
|
||||
("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"),
|
||||
]
|
||||
|
||||
def __init__(self, split="train", cache_dir=".cache") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]):
|
||||
(prompt_template, answer_template) = random.choice(self.TEMPLATE)
|
||||
self.pairs.append(
|
||||
(
|
||||
prompt_template.format(row["context"], safety_annotation),
|
||||
answer_template.format(safe_answer, safety_annotation),
|
||||
)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
|
||||
|
||||
class ProsocialDialogue(Dataset):
|
||||
name = "prosocial_dialogue"
|
||||
"""
|
||||
ProsocialDialog, we set up a human-AI collaborative data creation framework,
|
||||
where GPT-3 generates the potentially unsafe utterances, and crowdworkers
|
||||
provide prosocial responses to them. This approach allows us to circumvent
|
||||
two substantial challenges:
|
||||
(1) there are no available large-scale corpora of multiturn prosocial conversations
|
||||
between humans
|
||||
(2) asking humans to write unethical, toxic, or problematic utterances could result
|
||||
in psychological harms (Roberts, 2017; Steiger et al., 2021).
|
||||
"""
|
||||
PREFIX = "<prefix>You are now a prosocial chatbot, be caution and casual when reply</prefix>"
|
||||
|
||||
def __init__(self, split="train", cache_dir=".cache") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
for answer in row["rots"]:
|
||||
self.pairs.append((self.PREFIX + row["context"], answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
List of translation dataset
|
||||
|
||||
GroNLP/divemt
|
||||
|
||||
fill in the blanks : https://huggingface.co/datasets/m_lama
|
||||
|
||||
"""
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# postfix prompt
|
||||
TRANSLATION_PROMPT = {
|
||||
"zh": [ # simplified or any chinese which was not mentioned
|
||||
"Translate to chinese simplified: {}",
|
||||
"{}, translate to chinese",
|
||||
"{} give me the chinese translation",
|
||||
"翻译成中文: {}",
|
||||
"{} 这句中文翻译怎麽写?",
|
||||
"我需要这句话的中文翻译: {}",
|
||||
],
|
||||
"zh-tw": [ # WMT code
|
||||
"{}. Translate to chinese traditional",
|
||||
"{}, translate to chinese",
|
||||
"{}. get chinese translation",
|
||||
"中文翻譯: {}",
|
||||
"幫我翻譯成中文: '{}'",
|
||||
"{} 這句中文翻譯怎麼寫?",
|
||||
],
|
||||
"ja": [
|
||||
"{}: help me translate to japanese",
|
||||
"Need japanese translation: {}",
|
||||
"{}: にほんごやくをよこす",
|
||||
"{}: にほんごやくをおくれ",
|
||||
"{}: にほんごやくを じょす",
|
||||
"give me the japanese translation, {}",
|
||||
],
|
||||
"de": [
|
||||
"{}: translate to german",
|
||||
"give me the german translation {}",
|
||||
"I want german translation {}",
|
||||
"{}, ins Deutsche übersetzen",
|
||||
"{}, Übersetzen ins Deutsche",
|
||||
],
|
||||
"fr": [
|
||||
"{}. translate to french",
|
||||
"{} write in french",
|
||||
"{} french translation",
|
||||
"{} ,donnez moi la traduction française",
|
||||
],
|
||||
"ko": [
|
||||
"{}. translate to Korean",
|
||||
"how do we write in korean: {}",
|
||||
"give me the korean translation: {}",
|
||||
"{}, 한국어 번역을 해주세요",
|
||||
],
|
||||
"ms": [
|
||||
"{} translate to malay",
|
||||
"{} how do we write in Malay",
|
||||
"{} give me the malay translation",
|
||||
"{} , berikan saya terjemahan dalam bahasa melayu",
|
||||
"{}, Jemahan di bahasa melayu" "{}, jemahkan ayat ini kepada bahasa melayu",
|
||||
],
|
||||
"en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
|
||||
"ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
|
||||
"tr": ["{}. türkçeye çevi̇ri̇n", "{} write in turkish", "turkish translation: '{}'", "türkçeye çevi̇rmek: {}"],
|
||||
"it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"],
|
||||
"nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
|
||||
"vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
|
||||
"ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
|
||||
}
|
||||
|
||||
|
||||
class TranslationPair(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pairs = []
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
def __init__(self, pair="zh-en", split="train") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("wmt19", pair)[split]
|
||||
self.pairs = []
|
||||
src, tgt = pair.split("-")
|
||||
for row in dataset:
|
||||
row = row["translation"]
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else: # translating in reverse direction
|
||||
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
|
||||
class DiveMT(TranslationPair):
|
||||
|
||||
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
|
||||
|
||||
def __init__(self, split="train") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("GroNLP/divemt", "main")[split]
|
||||
tgt, src = "tgt_text", "src_text"
|
||||
for row in dataset:
|
||||
# ISO 639-2
|
||||
lang_code_2 = row["subject_id"].split("_")[0]
|
||||
lang_code = self.REMAP[lang_code_2]
|
||||
if lang_code not in TRANSLATION_PROMPT:
|
||||
continue
|
||||
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else: # translating in reverse direction
|
||||
lang_code = "en"
|
||||
source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
|
||||
class TEDTalk(TranslationPair):
|
||||
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
|
||||
|
||||
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
|
||||
src, tgt = pair.split("-")
|
||||
for row in dataset:
|
||||
row = row["translation"]
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else: # translating in reverse direction
|
||||
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
@@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
def test_all_datasets():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke"]
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
|
||||
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
|
||||
|
||||
config = Namespace(cache_dir=".cache")
|
||||
for dataset_name in others + qa_base + summarize_base:
|
||||
for dataset_name in translation + others + summarize_base + qa_base:
|
||||
print(dataset_name)
|
||||
train, eval = get_one_dataset(config, dataset_name)
|
||||
# sanity check
|
||||
@@ -48,7 +49,3 @@ def test_collate_fn():
|
||||
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_collate_fn()
|
||||
|
||||
@@ -1,34 +1,27 @@
|
||||
import { OasstApiClient, OasstError } from "src/lib/oasst_api_client";
|
||||
import type { BackendUserCore } from "src/types/Users";
|
||||
|
||||
describe("Contract test for Oasst API", function () {
|
||||
// Assumes this is running the mock server.
|
||||
const oasstApiClient = new OasstApiClient("http://localhost:8080", "test");
|
||||
|
||||
const testUser = {
|
||||
id: "abcd",
|
||||
display_name: "test",
|
||||
auth_method: "local",
|
||||
} as BackendUserCore;
|
||||
|
||||
it("can fetch a task", async () => {
|
||||
expect(
|
||||
await oasstApiClient.fetchTask("random", {
|
||||
sub: "test",
|
||||
name: "test",
|
||||
email: "test",
|
||||
})
|
||||
).to.be.not.null;
|
||||
expect(await oasstApiClient.fetchTask("random", testUser)).to.be.not.null;
|
||||
});
|
||||
|
||||
it("can ack a task", async () => {
|
||||
const task = await oasstApiClient.fetchTask("random", {
|
||||
sub: "test",
|
||||
name: "test",
|
||||
email: "test",
|
||||
});
|
||||
const task = await oasstApiClient.fetchTask("random", testUser);
|
||||
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
|
||||
});
|
||||
|
||||
it("can record a taskInteraction", async () => {
|
||||
const task = await oasstApiClient.fetchTask("random", {
|
||||
sub: "test",
|
||||
name: "test",
|
||||
email: "test",
|
||||
});
|
||||
const task = await oasstApiClient.fetchTask("random", testUser);
|
||||
expect(
|
||||
await oasstApiClient.interactTask(
|
||||
"text_reply_to_message",
|
||||
@@ -36,11 +29,7 @@ describe("Contract test for Oasst API", function () {
|
||||
"321",
|
||||
"1",
|
||||
{ text: "Test" },
|
||||
{
|
||||
sub: "test",
|
||||
name: "test",
|
||||
email: "test",
|
||||
}
|
||||
testUser
|
||||
)
|
||||
).to.be.not.null;
|
||||
});
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
{
|
||||
"about": "About",
|
||||
"account_settings": "Account",
|
||||
"connect": "Connect",
|
||||
"conversational": "Conversational AI for everyone.",
|
||||
"dashboard": "Dashboard",
|
||||
"discord": "Discord",
|
||||
"github": "GitHub"
|
||||
"docs": "Docs",
|
||||
"github": "GitHub",
|
||||
"legal": "Legal",
|
||||
"privacy_policy": "Privacy Policy",
|
||||
"report_a_bug": "Report a Bug",
|
||||
"sign_in": "Sign In",
|
||||
"sign_out": "Sign Out",
|
||||
"terms_of_service": "Terms of Service",
|
||||
"title": "Open Assistant"
|
||||
}
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
|
||||
import { TaskTypes } from "../Tasks/TaskTypes";
|
||||
import { TaskCategory, TaskCategoryLabels, TaskTypes } from "../Tasks/TaskTypes";
|
||||
|
||||
export const TaskOption = ({ displayTaskCategories }) => {
|
||||
export const TaskOption = ({ displayTaskCategories }: { displayTaskCategories: TaskCategory[] }) => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
|
||||
return (
|
||||
<Box className="flex flex-col gap-14">
|
||||
{displayTaskCategories.map((category, categoryIndex) => (
|
||||
<div key={categoryIndex}>
|
||||
<Text className="text-2xl font-bold pb-4">{category}</Text>
|
||||
{displayTaskCategories.map((category) => (
|
||||
<div key={category}>
|
||||
<Text className="text-2xl font-bold pb-4">{TaskCategoryLabels[category]}</Text>
|
||||
<SimpleGrid columns={[1, 1, 2, 2, 3, 4]} gap={4}>
|
||||
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
|
||||
<Link key={itemIndex} href={item.pathname}>
|
||||
{TaskTypes.filter((task) => task.category === category).map((item) => (
|
||||
<Link key={category + item.label} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
borderRadius="xl"
|
||||
|
||||
@@ -25,8 +25,8 @@ import clsx from "clsx";
|
||||
import { useEffect, useReducer } from "react";
|
||||
import { FiAlertCircle } from "react-icons/fi";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWR from "swr";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
@@ -114,9 +114,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
}, [data, isLoading]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/set_label", post, {
|
||||
onSuccess: () => {
|
||||
setIsEditing.off();
|
||||
},
|
||||
onSuccess: setIsEditing.off,
|
||||
});
|
||||
|
||||
const submitResponse = () => {
|
||||
@@ -149,7 +147,7 @@ export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
isLazy
|
||||
lazyBehavior="keepMounted"
|
||||
>
|
||||
<Box display="flex" alignItems="center" gap="2">
|
||||
<Box display="flex" alignItems="center" flexDirection={["column", "row"]} gap="2">
|
||||
<PopoverAnchor>{props.children}</PopoverAnchor>
|
||||
|
||||
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import { Box, Divider, Flex, Text, useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useMemo } from "react";
|
||||
|
||||
export function Footer() {
|
||||
const { t } = useTranslation();
|
||||
const { colorMode } = useColorMode();
|
||||
const backgroundColor = colorMode === "light" ? "white" : "gray.800";
|
||||
const textColor = colorMode === "light" ? "black" : "gray.300";
|
||||
@@ -33,10 +35,10 @@ export function Footer() {
|
||||
|
||||
<Box>
|
||||
<Text fontSize="md" fontWeight="bold">
|
||||
Open Assistant
|
||||
{t("title")}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="gray.500">
|
||||
Conversational AI for everyone.
|
||||
{t("conversational")}
|
||||
</Text>
|
||||
</Box>
|
||||
</Flex>
|
||||
@@ -45,23 +47,23 @@ export function Footer() {
|
||||
<Box display="flex" flexDirection={["column", "row"]} gap={["6", "14"]} fontSize="sm">
|
||||
<Flex direction="column" alignItems={["center", "start"]}>
|
||||
<Text fontWeight="bold" color={textColor}>
|
||||
Legal
|
||||
{t("legal")}
|
||||
</Text>
|
||||
<FooterLink href="/privacy-policy" label="Privacy Policy" />
|
||||
<FooterLink href="/terms-of-service" label="Terms of Service" />
|
||||
<FooterLink href="/privacy-policy" label={t("privacy_policy")} />
|
||||
<FooterLink href="/terms-of-service" label={t("terms_of_service")} />
|
||||
</Flex>
|
||||
<Flex direction="column" alignItems={["center", "start"]}>
|
||||
<Text fontWeight="bold" color={textColor}>
|
||||
Connect
|
||||
{t("connect")}
|
||||
</Text>
|
||||
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label="Github" />
|
||||
<FooterLink href="https://ykilcher.com/open-assistant-discord" label="Discord" />
|
||||
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label={t("github")} />
|
||||
<FooterLink href="https://ykilcher.com/open-assistant-discord" label={t("discord")} />
|
||||
</Flex>
|
||||
<Flex direction="column" alignItems={["center", "start"]}>
|
||||
<Text fontWeight="bold" color={textColor}>
|
||||
About
|
||||
{t("about")}
|
||||
</Text>
|
||||
<FooterLink href="https://projects.laion.ai/Open-Assistant" label="Docs" />
|
||||
<FooterLink href="https://projects.laion.ai/Open-Assistant" label={t("docs")} />
|
||||
</Flex>
|
||||
</Box>
|
||||
</nav>
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { Box, Button, Text, Flex } from "@chakra-ui/react";
|
||||
import { Box, Button, Flex, Text } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { Flags } from "react-feature-flags";
|
||||
import { FaUser } from "react-icons/fa";
|
||||
|
||||
@@ -23,7 +24,8 @@ function AccountButton() {
|
||||
);
|
||||
}
|
||||
|
||||
export function Header(props) {
|
||||
export function Header() {
|
||||
const { t } = useTranslation();
|
||||
const { data: session } = useSession();
|
||||
const homeURL = session ? "/dashboard" : "/";
|
||||
|
||||
@@ -34,7 +36,7 @@ export function Header(props) {
|
||||
<Flex alignItems="center">
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
|
||||
<Text fontFamily="inter" fontSize="2xl" fontWeight="bold" ml="3">
|
||||
Open Assistant
|
||||
{t("title")}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Link>
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
} from "@chakra-ui/react";
|
||||
import NextLink from "next/link";
|
||||
import { signOut, useSession } from "next-auth/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import React, { ElementType, useCallback } from "react";
|
||||
import { FiAlertTriangle, FiLayout, FiLogOut, FiSettings, FiShield } from "react-icons/fi";
|
||||
|
||||
@@ -25,6 +26,7 @@ interface MenuOption {
|
||||
}
|
||||
|
||||
export function UserMenu() {
|
||||
const { t } = useTranslation();
|
||||
const borderColor = useColorModeValue("gray.300", "gray.600");
|
||||
const handleSignOut = useCallback(() => {
|
||||
signOut({ callbackUrl: "/" });
|
||||
@@ -36,23 +38,23 @@ export function UserMenu() {
|
||||
}
|
||||
const options: MenuOption[] = [
|
||||
{
|
||||
name: "Dashboard",
|
||||
name: t("dashboard"),
|
||||
href: "/dashboard",
|
||||
desc: "Dashboard",
|
||||
desc: t("dashboard"),
|
||||
icon: FiLayout,
|
||||
isExternal: false,
|
||||
},
|
||||
{
|
||||
name: "Account Settings",
|
||||
name: t("account_settings"),
|
||||
href: "/account",
|
||||
desc: "Account Settings",
|
||||
desc: t("account_settings"),
|
||||
icon: FiSettings,
|
||||
isExternal: false,
|
||||
},
|
||||
{
|
||||
name: "Report a Bug",
|
||||
name: t("report_a_bug"),
|
||||
href: "https://github.com/LAION-AI/Open-Assistant/issues/new/choose",
|
||||
desc: "Report a Bug",
|
||||
desc: t("report_a_bug"),
|
||||
icon: FiAlertTriangle,
|
||||
isExternal: true,
|
||||
},
|
||||
@@ -60,9 +62,9 @@ export function UserMenu() {
|
||||
|
||||
if (session.user.role === "admin") {
|
||||
options.unshift({
|
||||
name: "Admin Dashboard",
|
||||
name: t("admin_dashboard"),
|
||||
href: "/admin",
|
||||
desc: "Admin Dashboard",
|
||||
desc: t("admin_dashboard"),
|
||||
icon: FiShield,
|
||||
isExternal: false,
|
||||
});
|
||||
@@ -105,7 +107,7 @@ export function UserMenu() {
|
||||
<MenuDivider />
|
||||
<MenuItem gap="3" borderRadius="md" p="4" onClick={handleSignOut}>
|
||||
<FiLogOut className="text-blue-500" aria-hidden="true" />
|
||||
<Text>Sign Out</Text>
|
||||
<Text>{t("sign_out")}</Text>
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Box, Text, useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import { useTranslation } from "next-i18next";
|
||||
|
||||
import { Container } from "./Container";
|
||||
import { AnimatedCircles } from "./AnimatedCircles";
|
||||
import { Container } from "./Container";
|
||||
|
||||
export function Hero() {
|
||||
const { t } = useTranslation("index");
|
||||
|
||||
@@ -23,7 +23,7 @@ export const getDefaultLayout = (page: React.ReactElement) => (
|
||||
|
||||
export const getTransparentHeaderLayout = (page: React.ReactElement) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
<Header />
|
||||
{page}
|
||||
<Footer />
|
||||
</div>
|
||||
@@ -31,7 +31,7 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
|
||||
|
||||
export const getDashboardLayout = (page: React.ReactElement) => (
|
||||
<Grid templateRows="min-content 1fr" h="full">
|
||||
<Header transparent={true} />
|
||||
<Header />
|
||||
<SideMenuLayout
|
||||
menuButtonOptions={[
|
||||
{
|
||||
@@ -66,7 +66,7 @@ export const getDashboardLayout = (page: React.ReactElement) => (
|
||||
|
||||
export const getAdminLayout = (page: React.ReactElement) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
<Header />
|
||||
<SideMenuLayout
|
||||
menuButtonOptions={[
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@ interface MessageTableProps {
|
||||
|
||||
export function MessageTable({ messages, enableLink }: MessageTableProps) {
|
||||
return (
|
||||
<Stack spacing="3">
|
||||
<Stack spacing="4">
|
||||
{messages.map((item) => (
|
||||
<MessageTableEntry enabled={enableLink} item={item} key={item.id + item.frontend_message_id} />
|
||||
))}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Avatar, Box, HStack, LinkBox, useBreakpoint, useBreakpointValue, useColorModeValue } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import { FlaggableElement } from "src/components/FlaggableElement";
|
||||
import { Message } from "src/types/Conversation";
|
||||
|
||||
@@ -10,47 +12,48 @@ interface MessageTableEntryProps {
|
||||
}
|
||||
|
||||
export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
const router = useRouter();
|
||||
|
||||
const { item } = props;
|
||||
|
||||
const goToMessage = useCallback(() => router.push(`/messages/${item.id}`), [router, item.id]);
|
||||
|
||||
const backgroundColor = useColorModeValue("gray.100", "gray.700");
|
||||
const backgroundColor2 = useColorModeValue("#DFE8F1", "#42536B");
|
||||
|
||||
const avatarColor = useColorModeValue("white", "black");
|
||||
const borderColor = useColorModeValue("blackAlpha.200", "whiteAlpha.200");
|
||||
|
||||
const inlineAvatar = useBreakpointValue({ base: true, sm: false });
|
||||
|
||||
const avatar = useMemo(
|
||||
() => (
|
||||
<Avatar
|
||||
borderColor={borderColor}
|
||||
size={inlineAvatar ? "xs" : "sm"}
|
||||
mr={inlineAvatar ? 2 : 0}
|
||||
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
|
||||
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
|
||||
/>
|
||||
),
|
||||
[borderColor, inlineAvatar, item.is_assistant]
|
||||
);
|
||||
|
||||
return (
|
||||
<FlaggableElement message={item}>
|
||||
<HStack w={["full", "full", "full", "fit-content"]} gap={2}>
|
||||
<Box borderRadius="full" border="solid" borderWidth="1px" borderColor={borderColor} bg={avatarColor}>
|
||||
<Avatar
|
||||
size="sm"
|
||||
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
|
||||
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
|
||||
/>
|
||||
{!inlineAvatar && avatar}
|
||||
<Box
|
||||
width={["full", "full", "full", "fit-content"]}
|
||||
maxWidth={["full", "full", "full", "2xl"]}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
onClick={props.enabled && goToMessage}
|
||||
_hover={props.enabled && { cursor: "pointer", opacity: 0.9 }}
|
||||
>
|
||||
{inlineAvatar && avatar}
|
||||
{item.text}
|
||||
</Box>
|
||||
{props.enabled ? (
|
||||
<Box width={["full", "full", "full", "fit-content"]} maxWidth={["full", "full", "full", "2xl"]}>
|
||||
<Link href={`/messages/${item.id}`}>
|
||||
<LinkBox
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
whiteSpace="pre-line"
|
||||
>
|
||||
{item.text}
|
||||
</LinkBox>
|
||||
</Link>
|
||||
</Box>
|
||||
) : (
|
||||
<Box
|
||||
width={["full", "full", "full", "fit-content"]}
|
||||
maxWidth={["full", "full", "full", "2xl"]}
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
>
|
||||
{item.text}
|
||||
</Box>
|
||||
)}
|
||||
</HStack>
|
||||
</FlaggableElement>
|
||||
);
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import { Box, BoxProps, useColorModeValue } from "@chakra-ui/react";
|
||||
import clsx from "clsx";
|
||||
import { PropsWithChildren } from "react";
|
||||
|
||||
interface SurveyCardProps {
|
||||
className?: string;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export const SurveyCard = (props: SurveyCardProps) => {
|
||||
export const SurveyCard = (props: PropsWithChildren<{ className?: string }>) => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
|
||||
const BoxClasses: BoxProps = {
|
||||
gap: "2",
|
||||
borderRadius: "xl",
|
||||
shadow: "base",
|
||||
className: "p-4 sm:p-6",
|
||||
className: clsx("p-4 sm:p-6", props.className),
|
||||
};
|
||||
|
||||
return (
|
||||
<Box bg={backgroundColor} {...BoxClasses}>
|
||||
<Box as="section" bg={backgroundColor} {...BoxClasses}>
|
||||
{props.children}
|
||||
</Box>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
export enum TaskCategory {
|
||||
Tasks = "Tasks",
|
||||
Random = "Random",
|
||||
Create = "Create",
|
||||
Evaluate = "Evaluate",
|
||||
Label = "Label",
|
||||
@@ -20,12 +20,19 @@ export interface TaskInfo {
|
||||
unchanged_message?: string;
|
||||
}
|
||||
|
||||
export const TaskCategoryLabels: { [key in TaskCategory]: string } = {
|
||||
[TaskCategory.Random]: "I'm feeling lucky",
|
||||
[TaskCategory.Create]: "Create",
|
||||
[TaskCategory.Evaluate]: "Evaluate",
|
||||
[TaskCategory.Label]: "Label",
|
||||
};
|
||||
|
||||
export const TaskTypes: TaskInfo[] = [
|
||||
// general/random
|
||||
{
|
||||
label: "Start a Task",
|
||||
desc: "Help us improve Open Assistant by starting a random task.",
|
||||
category: TaskCategory.Tasks,
|
||||
category: TaskCategory.Random,
|
||||
pathname: "/tasks/random",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
type: "random",
|
||||
@@ -121,7 +128,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/label_prompter_reply",
|
||||
overview: "Given the following discussion, provide labels for the final prompt",
|
||||
overview: "Given the following discussion, provide labels for the final prompt.",
|
||||
type: "label_prompter_reply",
|
||||
mode: "full",
|
||||
update_type: "text_labels",
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
import { useState } from "react";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { BaseTask, TaskResponse } from "src/types/Task";
|
||||
import { BaseTask, TaskResponse, TaskType as TaskTypeEnum } from "src/types/Task";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskApiEndpoint: string) => {
|
||||
export const useGenericTaskAPI = <TaskType extends BaseTask>(taskType: TaskTypeEnum) => {
|
||||
type ConcreteTaskResponse = TaskResponse<TaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>("/api/new_task/" + taskApiEndpoint, get, {
|
||||
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>("/api/new_task/" + taskType, get, {
|
||||
onSuccess: (data) => setTasks([data]),
|
||||
revalidateOnMount: true,
|
||||
dedupingInterval: 500,
|
||||
});
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", post, {
|
||||
onSuccess: async (response) => {
|
||||
const newTask: ConcreteTaskResponse = response;
|
||||
onSuccess: async (newTask: ConcreteTaskResponse) => {
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
mutate();
|
||||
},
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
|
||||
export const getDefaultStaticProps = async ({ locale }) => ({
|
||||
props: {
|
||||
...(await serverSideTranslations(locale)),
|
||||
},
|
||||
});
|
||||
@@ -1,7 +1,7 @@
|
||||
import { JWT } from "next-auth/jwt";
|
||||
import type { Message } from "src/types/Conversation";
|
||||
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import type { BackendUser } from "src/types/Users";
|
||||
import type { AvailableTasks } from "src/types/Task";
|
||||
import type { BackendUser, BackendUserCore } from "src/types/Users";
|
||||
|
||||
export class OasstError {
|
||||
message: string;
|
||||
@@ -108,14 +108,10 @@ export class OasstApiClient {
|
||||
// TODO return a strongly typed Task?
|
||||
// This method is used to store a task in RegisteredTask.task.
|
||||
// This is a raw Json type, so we can't use it to strongly type the task.
|
||||
async fetchTask(taskType: string, userToken: JWT): Promise<any> {
|
||||
async fetchTask(taskType: string, user: BackendUserCore): Promise<any> {
|
||||
return this.post("/api/v1/tasks/", {
|
||||
type: taskType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name,
|
||||
auth_method: "local",
|
||||
},
|
||||
user,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -140,15 +136,11 @@ export class OasstApiClient {
|
||||
messageId: string,
|
||||
userMessageId: string,
|
||||
content: object,
|
||||
userToken: JWT
|
||||
user: BackendUserCore
|
||||
): Promise<any> {
|
||||
return this.post("/api/v1/tasks/interaction", {
|
||||
type: updateType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name,
|
||||
auth_method: "local",
|
||||
},
|
||||
user,
|
||||
task_id: taskId,
|
||||
message_id: messageId,
|
||||
user_message_id: userMessageId,
|
||||
@@ -224,6 +216,13 @@ export class OasstApiClient {
|
||||
async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise<LeaderboardReply> {
|
||||
return this.get(`/api/v1/leaderboards/${time_frame}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the counts of all tasks (some might be zero)
|
||||
*/
|
||||
async fetch_available_tasks(user: BackendUserCore): Promise<AvailableTasks> {
|
||||
return this.post(`/api/v1/tasks/availability`, user);
|
||||
}
|
||||
}
|
||||
|
||||
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import prisma from "src/lib/prismadb";
|
||||
import type { BackendUserCore } from "src/types/Users";
|
||||
|
||||
/**
|
||||
* Returns a `BackendUserCore` that can be used for interacting with the Backend service.
|
||||
*
|
||||
* @param {string} id The user's web auth id.
|
||||
*
|
||||
* @return {BackendUserCore} The most specific auth type and id for the user.
|
||||
*/
|
||||
const getBackendUserCore = async (id: string) => {
|
||||
const user = await prisma.user.findUnique({
|
||||
where: { id },
|
||||
select: {
|
||||
id: true,
|
||||
name: true,
|
||||
accounts: true,
|
||||
},
|
||||
});
|
||||
|
||||
// If there are no linked accounts, just use what we have locally.
|
||||
if (user.accounts.length === 0) {
|
||||
return {
|
||||
id: user.id,
|
||||
display_name: user.name,
|
||||
auth_method: "local",
|
||||
} as BackendUserCore;
|
||||
}
|
||||
|
||||
// Otherwise, use the first linked account that the user created.
|
||||
return {
|
||||
id: user.accounts[0].providerAccountId,
|
||||
display_name: user.name,
|
||||
auth_method: user.accounts[0].provider,
|
||||
} as BackendUserCore;
|
||||
};
|
||||
|
||||
export { getBackendUserCore };
|
||||
@@ -3,6 +3,7 @@ import Head from "next/head";
|
||||
import { FiAlertTriangle } from "react-icons/fi";
|
||||
import { EmptyState } from "src/components/EmptyState";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
function Error() {
|
||||
return (
|
||||
|
||||
@@ -3,6 +3,7 @@ import Head from "next/head";
|
||||
import { FiAlertTriangle } from "react-icons/fi";
|
||||
import { EmptyState } from "src/components/EmptyState";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
function ServerError() {
|
||||
return (
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Container } from "src/components/Container";
|
||||
import Roadmap from "src/components/Roadmap";
|
||||
import Services from "src/components/Services";
|
||||
import Vision from "src/components/Vision";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const AboutPage = () => {
|
||||
return (
|
||||
|
||||
@@ -4,6 +4,7 @@ import Router from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import React from "react";
|
||||
import { Control, useForm, useWatch } from "react-hook-form";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { Button } from "@chakra-ui/react";
|
||||
import { Button, Divider, Flex, Grid, Icon, Text } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import React from "react";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
import { MdOutlineEdit } from "react-icons/md";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
@@ -19,15 +22,28 @@ export default function Account() {
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<div className="oa-basic-theme">
|
||||
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<Button>
|
||||
<Link href="/account/edit">Edit Username</Link>
|
||||
</Button>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
</div>
|
||||
<main className="oa-basic-theme p-6">
|
||||
<Flex m="auto" className="max-w-7xl" alignContent="center">
|
||||
<SurveyCard className="w-full">
|
||||
<Text as="b" display="block" fontSize="2xl" py={2}>
|
||||
Your Account
|
||||
</Text>
|
||||
<Divider />
|
||||
<Grid gridTemplateColumns="repeat(2, max-content)" alignItems="center" gap={6} py={4}>
|
||||
<Text as="b">Username</Text>
|
||||
<Flex gap={2}>
|
||||
{session.user.name ?? "(No username)"}
|
||||
<Link href="/account/edit">
|
||||
<Icon boxSize={5} as={MdOutlineEdit} />
|
||||
</Link>
|
||||
</Flex>
|
||||
<Text as="b">Email</Text>
|
||||
<Text>{session.user.email ?? "(No Email)"}</Text>
|
||||
</Grid>
|
||||
<p></p>
|
||||
</SurveyCard>
|
||||
</Flex>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useSession } from "next-auth/react";
|
||||
import { useEffect } from "react";
|
||||
import { getAdminLayout } from "src/components/Layout";
|
||||
import { UserTable } from "src/components/UserTable";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
/**
|
||||
* Provides the admin index page that will display a list of users and give
|
||||
|
||||
@@ -3,6 +3,7 @@ import { InferGetServerSidePropsType } from "next";
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
import { useEffect } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { getAdminLayout } from "src/components/Layout";
|
||||
@@ -111,7 +112,7 @@ const ManageUser = ({ user }: InferGetServerSidePropsType<typeof getServerSidePr
|
||||
/**
|
||||
* Fetch the user's data on the server side when rendering.
|
||||
*/
|
||||
export async function getServerSideProps({ query }) {
|
||||
export async function getServerSideProps({ query, locale }) {
|
||||
const backend_user = await oasstApiClient.fetch_user(query.id);
|
||||
const local_user = await prisma.user.findUnique({
|
||||
where: { id: backend_user.id },
|
||||
@@ -126,6 +127,7 @@ export async function getServerSideProps({ query }) {
|
||||
return {
|
||||
props: {
|
||||
user,
|
||||
...(await serverSideTranslations(locale, ["common"])),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
const availableTasks = await oasstApiClient.fetch_available_tasks(user);
|
||||
res.status(200).json(availableTasks);
|
||||
});
|
||||
|
||||
export default handler;
|
||||
@@ -1,6 +1,7 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
/**
|
||||
* Returns a new task created from the Task Backend. We do a few things here:
|
||||
@@ -14,9 +15,10 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
// Fetch the new task.
|
||||
const { task_type } = req.query;
|
||||
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
let task;
|
||||
try {
|
||||
task = await oasstApiClient.fetchTask(task_type as string, token);
|
||||
task = await oasstApiClient.fetchTask(task_type as string, user);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
res.status(500).json(err);
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Prisma } from "@prisma/client";
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { getBackendUserCore } from "src/lib/users";
|
||||
|
||||
/**
|
||||
* Stores the task interaction with the Task Backend and then returns the next task generated.
|
||||
@@ -39,9 +40,10 @@ const handler = withoutRole("banned", async (req, res, token) => {
|
||||
},
|
||||
});
|
||||
|
||||
const user = await getBackendUserCore(token.sub);
|
||||
let newTask;
|
||||
try {
|
||||
newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, token);
|
||||
newTask = await oasstApiClient.interactTask(update_type, taskId, frontendId, interaction.id, content, user);
|
||||
} catch (err) {
|
||||
console.error(JSON.stringify(err));
|
||||
return res.status(500).json(err);
|
||||
|
||||
@@ -5,6 +5,7 @@ import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { ClientSafeProvider, getProviders, signIn } from "next-auth/react";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
@@ -47,7 +48,6 @@ interface SigninProps {
|
||||
function Signin({ providers }: SigninProps) {
|
||||
const router = useRouter();
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const [error, setError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
@@ -151,7 +151,7 @@ function Signin({ providers }: SigninProps) {
|
||||
|
||||
Signin.getLayout = (page) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
<Header />
|
||||
{page}
|
||||
<Footer />
|
||||
</div>
|
||||
@@ -209,11 +209,12 @@ const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSaf
|
||||
);
|
||||
};
|
||||
|
||||
export const getServerSideProps: GetServerSideProps<SigninProps> = async () => {
|
||||
export const getServerSideProps: GetServerSideProps<SigninProps> = async ({ locale }) => {
|
||||
const providers = await getProviders();
|
||||
return {
|
||||
props: {
|
||||
providers,
|
||||
...(await serverSideTranslations(locale, ["common"])),
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreateAssistantReply } from "src/hooks/tasks/useCreateReply";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreateInitialPrompt } from "src/hooks/tasks/useCreateReply";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const InitialPrompt = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useCreatePrompterReply } from "src/hooks/tasks/useCreateReply";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const UserReply = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useMemo } from "react";
|
||||
import { LeaderboardTable, TaskOption, WelcomeCard } from "src/components/Dashboard";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { TaskCategory } from "src/components/Tasks/TaskTypes";
|
||||
import { get } from "src/lib/api";
|
||||
import type { AvailableTasks, TaskType } from "src/types/Task";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const Dashboard = () => {
|
||||
const { data } = useSWRImmutable<AvailableTasks>("/api/available_tasks", get);
|
||||
|
||||
// TODO: show only these tasks:
|
||||
const availableTasks = useMemo(() => filterAvailableTasks(data ?? {}), [data]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -13,13 +23,19 @@ const Dashboard = () => {
|
||||
</Head>
|
||||
<Flex direction="column" gap="10">
|
||||
<WelcomeCard />
|
||||
<TaskOption displayTaskCategories={[TaskCategory.Tasks]} />
|
||||
<TaskOption displayTaskCategories={[TaskCategory.Random]} />
|
||||
<LeaderboardTable />
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
Dashboard.getLayout = (page) => getDashboardLayout(page);
|
||||
Dashboard.getLayout = getDashboardLayout;
|
||||
|
||||
export default Dashboard;
|
||||
|
||||
const filterAvailableTasks = (availableTasks: Partial<AvailableTasks>) =>
|
||||
Object.entries(availableTasks)
|
||||
.filter(([_, count]) => count > 0)
|
||||
.sort((a, b) => b[1] - a[1])
|
||||
.map(([taskType]) => taskType) as TaskType[];
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankAssistantRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const RankAssistantReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankInitialPromptsTask } from "src/hooks/tasks/useRankReplies";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const RankInitialPrompts = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useRankPrompterRepliesTask } from "src/hooks/tasks/useRankReplies";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const RankUserReplies = () => {
|
||||
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelAssistantReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const LabelAssistantReply = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelInitialPromptTask } from "src/hooks/tasks/useLabelingTask";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useLabelInitialPromptTask();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useLabelPrompterReplyTask } from "src/hooks/tasks/useLabelingTask";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-u
|
||||
import Head from "next/head";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LeaderboardGridCell } from "src/components/LeaderboardGridCell";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
const Leaderboard = () => {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { MessageLoading } from "src/components/Loading/MessageLoading";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
@@ -48,10 +49,13 @@ const MessageDetail = ({ id }: { id: string }) => {
|
||||
);
|
||||
};
|
||||
|
||||
MessageDetail.getInitialProps = async ({ query }) => {
|
||||
const { id } = query;
|
||||
return { id };
|
||||
};
|
||||
|
||||
MessageDetail.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export const getServerSideProps = async ({ locale, query }) => ({
|
||||
props: {
|
||||
id: query.id,
|
||||
...(await serverSideTranslations(locale, ["common"])),
|
||||
},
|
||||
});
|
||||
|
||||
export default MessageDetail;
|
||||
|
||||
@@ -4,6 +4,7 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { get } from "src/lib/api";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const MessagesDashboard = () => {
|
||||
const boxBgColor = useColorModeValue("white", "gray.800");
|
||||
|
||||
@@ -3,6 +3,7 @@ import Head from "next/head";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard";
|
||||
import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const PrivacyPolicy = () => {
|
||||
const backgroundColor = useColorModeValue("gray.100", "gray.800");
|
||||
|
||||
@@ -4,9 +4,10 @@ import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import { useGenericTaskAPI } from "src/hooks/tasks/useGenericTaskAPI";
|
||||
import { TaskType } from "src/types/Task";
|
||||
|
||||
const RandomTask = () => {
|
||||
const { tasks, isLoading, trigger, reset } = useGenericTaskAPI("random");
|
||||
const { tasks, isLoading, trigger, reset } = useGenericTaskAPI(TaskType.random);
|
||||
|
||||
if (isLoading) {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
|
||||
@@ -3,6 +3,7 @@ import Head from "next/head";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
import { PolicyChapterCard } from "src/components/PolicyCards/PolicyChapterCard";
|
||||
import { PolicySectionCard } from "src/components/PolicyCards/PolicySectionCard";
|
||||
export { getDefaultStaticProps as getStaticProps } from "src/lib/default_static_props";
|
||||
|
||||
const TermsOfService = () => {
|
||||
const TermsData = [
|
||||
|
||||
@@ -10,6 +10,8 @@ export const enum TaskType {
|
||||
label_initial_prompt = "label_initial_prompt",
|
||||
label_prompter_reply = "label_prompter_reply",
|
||||
label_assistant_reply = "label_assistant_reply",
|
||||
|
||||
random = "random",
|
||||
}
|
||||
|
||||
// we need to reconsider how to handle task content types
|
||||
@@ -32,3 +34,5 @@ export interface TaskResponse<Task extends BaseTask> {
|
||||
userId: string;
|
||||
task: Task;
|
||||
}
|
||||
|
||||
export type AvailableTasks = { [taskType in TaskType]: number };
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
/**
|
||||
* Reports the Backend's knowledge of a user.
|
||||
*/
|
||||
export interface BackendUser {
|
||||
export interface BackendUserCore {
|
||||
/**
|
||||
* The user's unique ID according to the `auth_method`.
|
||||
*/
|
||||
@@ -18,7 +15,12 @@ export interface BackendUser {
|
||||
* - local
|
||||
*/
|
||||
auth_method: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reports the Backend's knowledge of a user.
|
||||
*/
|
||||
export interface BackendUser extends BackendUserCore {
|
||||
/**
|
||||
* The backend's UUID for this user.
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user