Merge branch 'LAION-AI:main' into main

rasdani
2023-01-05 11:35:25 +01:00
committed by GitHub
51 changed files with 784 additions and 255 deletions
+9 -1
@@ -15,6 +15,9 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: actions/setup-node@v3
with:
node-version: 16
- run: cd oasst-shared && pip install -e .
@@ -22,9 +25,14 @@ jobs:
- run: cd backend && pip install -r requirements.txt
- run: cd website && npm install
- run: ./scripts/backend-development/start-mock-server.sh
- name: Run contract tests
- name: Run Python OasstApiClient contract tests
run: ./scripts/oasst-shared-development/test.sh
- name: Run JavaScript OasstApiClient contract tests
run: ./scripts/frontend-development/run-contract-test.sh
- run: ./scripts/backend-development/stop-mock-server.sh
+110
@@ -0,0 +1,110 @@
# I'm in! Now what?
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
this is for work coordination.
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
a dedicated channel and is more public.
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
dedicated, but not as active, channel.
[Visit the Notion](https://ykilcher.com/open-assistant)
### Taking on Tasks
We have a growing task list
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
that appeals to you and make a comment that you'd like to work on it. Include in
your comment a brief description of how you'll solve the problem and if there
are any open questions you want to discuss. Once a project coordinator has
assigned the issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in Discord
and someone can help clarify the issue with more detail.
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
the system architecture, and other documentation.
### Submitting Work
We're all working on different parts of Open Assistant together. To make
contributions smoothly we recommend the following:
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
and clone it to your local machine. (Read more
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
1. Before working on any changes, try to
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
to keep it up-to-date with the upstream repository.
1. Work on a small focused change that only touches on a few files.
1. Run `pre-commit` and make sure all files have formatting fixed. This
simplifies life for reviewers.
1. Package up a small bit of work that solves part of the problem
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
and
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
1. If you're lucky, we can merge your change into `main` without any problems.
If there are changes to files you're working on, resolve them by:
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved any conflicts, finish the review and merge into `main`.
1. Merge in your change and move onto a new issue or the second step of your
current issue.
Additionally, if someone is working on an issue that interests you, ask if they
need help on it or would like suggestions on how to approach the issue. If so,
share widely. If they seem to have a good handle on it, let them work on their
solution until a challenge comes up.
### When does a review finish
A review finishes when all blocking comments are addressed and at least one
owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
comments either by making the requested change, explaining why it's not being
addressed now, or filing an issue to handle it later.
## Developer Setup
Work is organized in the
[project board](https://github.com/orgs/LAION-AI/projects/3).
**Anything in the `Todo` column that is not assigned is up for grabs, meaning
we'd be happy for anyone to do these tasks.**
If you want to work on something, assign yourself to it or write a comment that
you want to work on it and what you plan to do.
- To get started with development, if you want to work on the backend, have a
look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at
`scripts/frontend-development/README.md` to make a backend available.
There is also a minimal implementation of a frontend in the `text-frontend`
folder.
We are using Python 3.10 for the backend.
Check out the
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
### Website
The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
### Deployment
Upon making a release on GitHub, all docker images are automatically built and
pushed to ghcr.io. The docker images are tagged with the release version, and
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
automatically deploy the built release to the dev machine.
+33 -126
@@ -1,7 +1,24 @@
# Open-Assistant
<h1 align="center">
<span>Open-Assistant</span>
<img width="auto" height="50px" src="https://github.com/LAION-AI/Open-Assistant/blob/main/assets/logo_crop.png"/>
</h1>
Open Assistant is a project meant to give everyone access to a great chat based
large language model.
# Table of Contents
- [What is Open Assistant?](#what-is-open-assistant)
- [Do you want to try it out?](#do-you-want-to-try-it-out)
- [The Plan](#the-plan)
- [The Vision](#the-vision)
- [How can you help?](#how-can-you-help)
- [I'm in! How do I contribute?](CONTRIBUTING.md)
---
## What is Open Assistant?
<p align="center">
Open Assistant is a project meant to give everyone access to a great chat based large language model.
</p>
We believe that by doing this we will create a revolution in innovation in
language. In the same way that stable-diffusion helped the world make art and
@@ -14,7 +31,7 @@ If you are interested in taking a look at the current state of the project, you
can set up an entire stack needed to run **Open-Assistant**, including the
website, backend, and associated dependent services.
To start the demo, run this in the root directory of the repository:
##### To start the demo, run this in the root directory of the repository:
```sh
docker compose up --build
@@ -23,22 +40,21 @@ docker compose up --build
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
interact with the website.
**Note:** When logging in via email, navigate to `http://localhost:1080` to get
the magic email login link.
> **Note:** When logging in via email, navigate to `http://localhost:1080` to
> get the magic email login link.
**Note:** If you would like to run this in a standardized development
environment (a
["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
using
[vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
or in a web browser using
[GitHub Codespaces](https://github.com/features/codespaces), you can use the
provided [`.devcontainer`](.devcontainer/) folder.
> **Note:** If you would like to run this in a standardized development
> environment (a
> ["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
> using
> [vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
> or in a web browser using
> [GitHub Codespaces](https://github.com/features/codespaces), you can use the
> provided [`.devcontainer`](.devcontainer/) folder.
## The Plan
We want to get to an initial MVP as fast as possible, by following the 3-steps
outlined in the InstructGPT paper.
##### We want to get to an initial MVP as fast as possible, by following the 3 steps outlined in the InstructGPT paper.
1. Collect high-quality human generated Instruction-Fulfillment samples
(prompt + response), goal >50k. We design a crowdsourced process to collect
@@ -80,113 +96,4 @@ All open source projects begin with people like you. Open source is the belief
that if we collaborate we can together gift our knowledge and technology to the
world for the benefit of humanity.
## I'm in! Now what?
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
this is for work coordination.
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
a dedicated channel and is more public.
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
dedicated, but not as active, channel.
[Visit the Notion](https://ykilcher.com/open-assistant)
### Taking on Tasks
We have a growing task list
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
that appeals to you and make a comment that you'd like to work on it. Include in
your comment a brief description of how you'll solve the problem and if there
are any open questions you want to discuss. Once a project coordinator has
assigned the issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in Discord
and someone can help clarify the issue with more detail.
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
the system architecture, and other documentation.
### Submitting Work
We're all working on different parts of Open Assistant together. To make
contributions smoothly we recommend the following:
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
and clone it to your local machine. (Read more
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
1. Before working on any changes, try to
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
to keep it up-to-date with the upstream repository.
1. Work on a small focused change that only touches on a few files.
1. Run `pre-commit` and make sure all files have formatting fixed. This
simplifies life for reviewers.
1. Package up a small bit of work that solves part of the problem
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
and
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
1. If you're lucky, we can merge your change into `main` without any problems.
If there are changes to files you're working on, resolve them by:
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved any conflicts, finish the review and merge into `main`.
1. Merge in your change and move onto a new issue or the second step of your
current issue.
Additionally, if someone is working on an issue that interests you, ask if they
need help on it or would like suggestions on how to approach the issue. If so,
share wildly. If they seem to have a good handle on it, let them work on their
solution until a challenge comes up.
### When does a review finish
A review finishes when all blocking comments are addressed and at least one
owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
comments either by making the requested change, explaining why it's not being
addressed now, or filing an issue to handle it later.
## Developer Setup
Work is organized in the
[project board](https://github.com/orgs/LAION-AI/projects/3).
**Anything in the `Todo` column that is not assigned is up for grabs, meaning
we'd be happy for anyone to do these tasks.**
If you want to work on something, assign yourself to it or write a comment that
you want to work on it and what you plan to do.
- To get started with development, if you want to work on the backend, have a
look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at
`scripts/frontend-development/README.md` to make a backend available.
There is also a minimal implementation of a frontend in the `text-frontend`
folder.
We are using Python 3.10 for the backend.
Check out the
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
### Website
The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
### Deployment
Upon making a release on GitHub, all docker images are automatically built and
pushed to ghcr.io. The docker images are tagged with the release version, and
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
automatically deploy the built release to the dev machine.
Check out our [contributing guide](CONTRIBUTING.md) to get started.
+1
@@ -82,6 +82,7 @@
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
ports:
- 8080:8080
Binary file not shown (size: 24 KiB).

+4 -2
@@ -11,10 +11,12 @@ Example contents of a `.env` file for the backend:
```
DATABASE_URI="postgresql://<username>:<password>@<host>/<database_name>"
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"]
REDIS_HOST=localhost
REDIS_PORT=6379
```
## Running the REST Server locally for development
Have a look into the main `README.md` file for more information on how to set up
the backend for development.
the backend for development. Use the scripts within the
`scripts/backend-development` folder to run the backend API locally.
+1
@@ -36,3 +36,4 @@ environments:
secrets:
# Note: URI, not URL.
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
REDIS_HOST: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/REDIS_HOST
@@ -0,0 +1,16 @@
model_name: kalpeshk2011/rankgen-t5-base-all
tokenizer_name: google/t5-v1_1-base
learning_rate: 6e-6
gradient_checkpointing: false
fp16: true
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
freeze_layer: 20
eval_steps: 200
save_steps: 500
max_length: 400
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -0,0 +1,19 @@
model_name: kalpeshk2011/rankgen-t5-base-all
# model_name: kalpeshk2011/rankgen-t5-xl-all
# model_name: kalpeshk2011/rankgen-t5-xl-pg19
# model_name: kalpeshk2011/rankgen-t5-large-all
tokenizer_name: google/t5-v1_1-base
learning_rate: 6e-6
gradient_checkpointing: false
fp16: false
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
freeze_layer: 20
eval_steps: 200
save_steps: 500
max_length: 400
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
+27
@@ -0,0 +1,27 @@
import torch
from transformers import AutoModel
class RankGenModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
self.rankgen_hf_hub = model_name
assert model_name in [
"kalpeshk2011/rankgen-t5-xl-all",
"kalpeshk2011/rankgen-t5-xl-pg19",
"kalpeshk2011/rankgen-t5-base-all",
"kalpeshk2011/rankgen-t5-large-all",
]
self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)
def forward(self, prefixes, suffixes):
# print(list(self.model.parameters()))
# raise Exception("stop")
embedded_prefixes = self.model(**prefixes)
embedded_suffixes = self.model(**suffixes)
# take dot product of each row independently
dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
# print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}")
# raise Exception("stop")
return dot_products
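The `forward` method above scores each prefix/suffix pair with a row-wise dot product of their embeddings. The following is a minimal pure-Python sketch of that one step, using nested lists in place of tensors. The helper name `rowwise_dot` is hypothetical and is not part of the commit.

```python
def rowwise_dot(prefixes, suffixes):
    """Dot product of each row pair, mirroring torch.sum(a * b, dim=1) on 2-D inputs."""
    if len(prefixes) != len(suffixes):
        raise ValueError("prefixes and suffixes must have the same number of rows")
    return [
        sum(p * s for p, s in zip(p_row, s_row))
        for p_row, s_row in zip(prefixes, suffixes)
    ]


# Two prefix embeddings scored against two suffix embeddings, one score per row.
scores = rowwise_dot([[1.0, 2.0], [0.5, 0.0]], [[3.0, 4.0], [2.0, 9.0]])
print(scores)  # [11.0, 1.0]
```

Each row yields one scalar, so a batch of N pairs produces N scores rather than an N x N similarity matrix.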
+30
@@ -22,11 +22,41 @@ from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
@dataclass
class RankGenCollator:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
max_examples: Optional[int] = None
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
prefixes = []
better_answers = []
worse_answers = []
for question, pairs in batch:
for (pos, neg) in pairs:
prefixes.append("pre " + question)
better_answers.append("suffi " + pos)
worse_answers.append("suffi " + neg)
tokenized_prefixes = self.tokenizer(
prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
tokenized_pos = self.tokenizer(
better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
tokenized_neg = self.tokenizer(
worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg}
@dataclass
class DataCollatorForPairRank:
"""
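As a rough illustration of what `RankGenCollator.__call__` does before tokenization, the sketch below flattens a batch of `(question, pairs)` items into three parallel lists. It assumes each batch item is a tuple of a question string and a list of `(better, worse)` answer pairs, as the loop in the diff suggests. The function name `flatten_pairs` is hypothetical.

```python
def flatten_pairs(batch):
    """Expand (question, [(pos, neg), ...]) items into three aligned lists."""
    prefixes, better_answers, worse_answers = [], [], []
    for question, pairs in batch:
        for pos, neg in pairs:
            # "pre " and "suffi " are the markers used in the diff above.
            prefixes.append("pre " + question)
            better_answers.append("suffi " + pos)
            worse_answers.append("suffi " + neg)
    return prefixes, better_answers, worse_answers


batch = [("What is 2+2?", [("4", "5"), ("four", "22")])]
prefixes, better, worse = flatten_pairs(batch)
print(prefixes)  # ['pre What is 2+2?', 'pre What is 2+2?']
print(better)    # ['suffi 4', 'suffi four']
print(worse)     # ['suffi 5', 'suffi 22']
```

The question is repeated once per answer pair, so index `i` in all three lists always refers to the same comparison.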
+2 -1
@@ -1,6 +1,7 @@
datasets==2.8.0
evaluate==0.4.0
scikit-learn==1.2.0
torch==1.12.1+cu116
sentencepiece==0.1.97
torch>=1.12.1
transformers==4.25.1
wandb==0.13.7
+57 -26
@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
@@ -46,14 +47,16 @@ class RankLoss(nn.Module):
self.log_sigmoid = nn.LogSigmoid()
def forward(self, pos, neg):
return -self.log_sigmoid(pos - neg + self.eps).mean()
loss = -self.log_sigmoid(pos - neg + self.eps).mean()
return loss
class RankTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
model_name: str = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
@@ -79,15 +82,25 @@ class RankTrainer(Trainer):
)
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = args.loss_function
self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
# forward pass
outputs = model(**inputs)
logits = outputs.get("logits").view(-1, 2)
if self.loss_function == "rank":
loss = self.loss_fct(logits[:, 0], logits[:, 1])
if "rankgen" in self.model_name:
positive_outputs = model(inputs["prefix"], inputs["positive"])
negative_outputs = model(inputs["prefix"], inputs["negative"])
if self.loss_function == "rank":
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
else:
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
outputs = model(**inputs)
logits = outputs.get("logits").view(-1, 2)
if self.loss_function == "rank":
loss = self.loss_fct(logits[:, 0], logits[:, 1])
else:
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
return (loss, outputs) if return_outputs else loss
@@ -109,24 +122,37 @@ class RankTrainer(Trainer):
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.inference_mode():
if "rankgen" in self.model_name:
inputs = self._prepare_inputs(inputs)
positive_outputs = model(inputs["prefix"], inputs["positive"])
negative_outputs = model(inputs["prefix"], inputs["negative"])
if self.loss_function == "rank":
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
return (loss, outputs, None)
else:
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs)
with torch.no_grad():
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs)
loss = loss.mean().detach()
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
if self.args.prediction_loss_only:
return (loss, None, None)
loss = loss.mean().detach()
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
if self.args.prediction_loss_only:
return (loss, None, None)
return (loss, logits, labels)
return (loss, logits, labels)
if __name__ == "__main__":
training_conf = argument_parsing(parser)
model_name = training_conf["model_name"]
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "rankgen-t5" in model_name:
model = RankGenModel(model_name)
else:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "freeze_layer" in training_conf:
num_layer = training_conf["freeze_layer"]
model = freeze_top_n_layers(model, num_layer)
@@ -134,7 +160,6 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
tokenizer = get_tokenizer(model_name)
args = CustomTrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
@@ -142,7 +167,7 @@ if __name__ == "__main__":
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=True,
fp16=training_conf["fp16"],
gradient_checkpointing=training_conf["gradient_checkpointing"],
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
@@ -154,7 +179,7 @@ if __name__ == "__main__":
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
report_to="wandb",
report_to="local",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
@@ -169,17 +194,23 @@ if __name__ == "__main__":
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
collate_fn = DataCollatorForPairRank(
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
trainer = RankTrainer(
model,
args,
model=model,
model_name=model_name,
args=args,
train_dataset=train,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
# trainer.evaluate()
trainer.train()
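The `RankLoss` used in this file computes `-log_sigmoid(pos - neg + eps)` averaged over the batch. The sketch below reproduces that formula with the standard library only, to show how the loss behaves when the preferred answer scores higher or lower than the rejected one. The value of `eps` is an assumption, since the diff does not show how `self.eps` is set, and the function names are hypothetical.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def rank_loss(pos, neg, eps=1e-8):
    """Mean of -log_sigmoid(pos - neg + eps) over paired scores (eps value assumed)."""
    if len(pos) != len(neg) or not pos:
        raise ValueError("pos and neg must be non-empty and the same length")
    return -sum(log_sigmoid(p - n + eps) for p, n in zip(pos, neg)) / len(pos)


# Correct ordering (pos > neg) gives a small loss; reversed ordering a large one.
print(rank_loss([2.0], [0.0]))
print(rank_loss([0.0], [2.0]))
```

With equal scores the loss is close to `log(2)`, and it shrinks toward zero as the margin `pos - neg` grows.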
+11 -5
@@ -3,7 +3,7 @@ import re
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, T5Tokenizer
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
@@ -25,7 +25,10 @@ def webgpt_return_format(row):
def get_tokenizer(tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "t5" in tokenizer_name: # rankgen
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "galactica" in tokenizer_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
@@ -67,6 +70,10 @@ def freeze_top_n_layers(model, target_layers):
def argument_parsing(parser):
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
default_params = {
"num_train_epochs": 4,
"learning_rate": 3e-5,
@@ -78,10 +85,9 @@ def argument_parsing(parser):
"gradient_accumulation_steps": 8,
"gradient_checkpointing": False,
"datasets": ["webgpt"],
"fp16": True,
"tokenizer_name": training_conf["model_name"],
}
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
params = {**default_params, **training_conf}
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
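The reordering in `argument_parsing` matters because the `tokenizer_name` default is read from the loaded config, so the YAML file has to be parsed before the defaults are built. A minimal sketch of the merge, taking an already-parsed dict instead of reading a file and listing only some of the defaults shown in the diff (the function name `merge_training_config` is hypothetical):

```python
def merge_training_config(training_conf: dict) -> dict:
    """Overlay values loaded from a YAML config on top of defaults."""
    default_params = {
        "num_train_epochs": 4,
        "learning_rate": 3e-5,
        "gradient_accumulation_steps": 8,
        "gradient_checkpointing": False,
        "datasets": ["webgpt"],
        "fp16": True,
        # Fall back to the model's own tokenizer unless the config names another.
        "tokenizer_name": training_conf["model_name"],
    }
    # Later keys win, so anything set in the config overrides the default.
    params = {**default_params, **training_conf}
    params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
    return params


conf = merge_training_config(
    {
        "model_name": "kalpeshk2011/rankgen-t5-base-all",
        "tokenizer_name": "google/t5-v1_1-base",
        "fp16": False,
    }
)
print(conf["tokenizer_name"], conf["fp16"])  # google/t5-v1_1-base False
```

A config that omits `tokenizer_name` gets the model name back as its tokenizer, which preserves the previous behavior for non-RankGen models.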
+2 -2
@@ -33,6 +33,6 @@ Experimental results in wandb
## TODOS
- decide on a model
- add special token to declare prompt and reply. Do not freeze the weights for
these
- Merge utils etc with reward model
- Causal modelling for GPT-JT does not leverage the bidirectional mask for the
prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1)
@@ -32,6 +32,17 @@ galactica-125:
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
gpt-jt:
learning_rate: 2e-6
model_name: togethercomputer/GPT-JT-6B-v1
weight_decay: 0.01
max_length: 1024
warmup_steps: 600
gradient_checkpointing: false
gradient_accumulation_steps: 2
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
debug:
eval_steps: 20
eval_size: 100
@@ -2,6 +2,8 @@ from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
class SquadV2Dataset(Dataset):
def __init__(self, cache_dir, split):
@@ -6,6 +6,8 @@ import torch
from torch.nn import functional as F
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from . import QA_SPECIAL_TOKENS
@dataclass
class DialogueDataCollator:
@@ -19,22 +21,21 @@ class DialogueDataCollator:
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
# TODO add special tokens for question and answer here
# additional_special_tokens = ['<question>', '<answer>']
prompt_tokens = ["Question: ", "Answer: "]
flatten_messages = []
label_masks = []
for messages in features:
assert len(messages) % 2 == 0, "Number of messages must be even"
messages = [
(prompt_tokens[0] if i % 2 == 0 else "") + x + ((" " + prompt_tokens[1]) if i % 2 == 0 else "")
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
+ x
+ (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
for i, x in enumerate(messages)
]
# Add a way for the model to terminate generation, reinitialize prompter
messages.append(prompt_tokens[0])
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_messages.append(
self.tokenizer(
@@ -47,8 +48,10 @@ class DialogueDataCollator:
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John.
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3
# Label mask is true when predicting a token that is part of the answer, false otherwise.
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
# LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0
# If no result in next, we are predicting the last termination token(s)
message_indices = list(
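To make the new message formatting concrete, the sketch below applies the same wrapping that the collator now uses: even-indexed messages (questions) are wrapped in `<question>` and `<answer>` tokens, and a trailing `<question>` token is appended so the model can learn where an answer ends. It covers only the string formatting, not tokenization or the label mask, and the function name `format_dialogue` is hypothetical.

```python
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}


def format_dialogue(messages):
    """Wrap alternating question/answer messages with the special tokens."""
    if len(messages) % 2 != 0:
        raise ValueError("Number of messages must be even")
    formatted = [
        (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
        + text
        + (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
        for i, text in enumerate(messages)
    ]
    # Predicting the start of the next question marks the end of the answer.
    formatted.append(QA_SPECIAL_TOKENS["Question"])
    return formatted


print(format_dialogue(["Hello, how are you?", "I am fine."]))
# ['<question>Hello, how are you?<answer>', 'I am fine.', '<question>']
```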
@@ -0,0 +1,6 @@
datasets==2.8.0
numpy==1.23.0
PyYAML==6.0
scikit_learn==1.2.0
torch==1.13.1
transformers==4.25.1
+4 -2
@@ -67,6 +67,8 @@ class SFTTrainer(Trainer):
optimizers,
preprocess_logits_for_metrics,
)
# By default CrossEntropyLoss uses ignore_index=-100, but just in case use our own loss_fct
self.loss_fct = get_loss(args.loss_function)
def fetch_scheduler(self):
@@ -112,7 +114,7 @@ class SFTTrainer(Trainer):
with torch.no_grad():
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
labels[~labels_mask] = -1
labels[~labels_mask] = -100 # padding_index
loss = loss.mean().detach()
@@ -159,8 +161,8 @@ def argument_parsing(notebook=False, notebook_args=None):
if __name__ == "__main__":
training_conf = argument_parsing()
model = get_model(training_conf)
tokenizer = get_tokenizer(training_conf)
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
+17 -3
@@ -1,14 +1,14 @@
from pathlib import Path
import yaml
from custom_datasets import get_one_dataset
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from losses import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
from transformers import AutoModelForCausalLM, AutoTokenizer
SUPPORTED_MODELS = ["galactica"]
SUPPORTED_MODELS = ["galactica", "GPT-JT"] # deprecated ..
def get_tokenizer(conf):
@@ -17,10 +17,19 @@ def get_tokenizer(conf):
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
additional_special_tokens = (
[]
if "additional_special_tokens" not in tokenizer.special_tokens_map
else tokenizer.special_tokens_map["additional_special_tokens"]
)
additional_special_tokens = list(set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
return tokenizer
def get_model(conf):
def get_model(conf, tokenizer):
if not any([x in conf.model_name for x in SUPPORTED_MODELS]):
raise ValueError(
f"Model {conf.model_name} not supported. Supported models: {SUPPORTED_MODELS}. "
@@ -29,6 +38,11 @@ def get_model(conf):
model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
if len(tokenizer) != model.get_input_embeddings().num_embeddings:
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
model.resize_token_embeddings(len(tokenizer))
if conf.freeze_layer:
model = freeze_top_n_layers(model, conf.freeze_layer)
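`get_tokenizer` merges any existing additional special tokens with the QA tokens and deduplicates them via `list(set(...))`, which does not guarantee a stable order. The sketch below shows an order-preserving alternative using `dict.fromkeys`. This is a swapped-in technique for illustration, not what the commit does, and the function name `merge_special_tokens` is hypothetical.

```python
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}


def merge_special_tokens(existing, extra):
    """Combine two token lists, dropping duplicates but keeping first-seen order."""
    return list(dict.fromkeys([*existing, *extra]))


existing = ["<work>", "<question>"]  # tokens a tokenizer might already define (made up)
merged = merge_special_tokens(existing, QA_SPECIAL_TOKENS.values())
print(merged)  # ['<work>', '<question>', '<answer>']
```

Whichever deduplication is used, the vocabulary can grow, which is why `get_model` compares `len(tokenizer)` with the embedding size and resizes the embeddings when they differ.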
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
import pydantic
from oasst_shared.exceptions import OasstErrorCode
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, conint, conlist, constr
class TaskRequestType(str, enum.Enum):
@@ -203,7 +203,7 @@ class TextReplyToMessage(Interaction):
type: Literal["text_reply_to_message"] = "text_reply_to_message"
message_id: str
user_message_id: str
text: str
text: constr(min_length=1, strip_whitespace=True)
class MessageRating(Interaction):
@@ -211,7 +211,7 @@ class MessageRating(Interaction):
type: Literal["message_rating"] = "message_rating"
message_id: str
rating: int
rating: conint(gt=0)
class MessageRanking(Interaction):
@@ -219,7 +219,7 @@ class MessageRanking(Interaction):
type: Literal["message_ranking"] = "message_ranking"
message_id: str
ranking: list[int]
ranking: conlist(item_type=int, min_items=1)
AnyInteraction = Union[
+11
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to website directory
pushd "$parent_path/../../website"
set -xe
npm run cypress:run:contract
popd
+9
@@ -0,0 +1,9 @@
import { defineConfig } from "cypress";
export default defineConfig({
e2e: {
// No baseUrl here, because we don't need it for contract testing
baseUrl: null,
specPattern: "cypress/contract/*.cy.{ts,js}",
},
});
@@ -0,0 +1,29 @@
import OasstApiClient from "src/lib/oasst_api_client";
describe("Contract test for Oasst API", function () {
// Assumes the mock server is running on localhost:8080.
const oasstApiClient = new OasstApiClient("http://localhost:8080", "test");
it("can fetch a task", async () => {
expect(
await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
})
).to.be.not.null;
});
it("can ack a task", async () => {
const task = await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
});
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
});
// TODO(#354): Add test for 204
// TODO(#354): Add test for parsing >=300, throwing an OasstError
// TODO(#354): Add test for parsing >=300, throwing a generic error
});
+8
@@ -2,6 +2,14 @@
const nextConfig = {
output: "standalone",
reactStrictMode: true,
images: {
remotePatterns: [
{
protocol: "https",
hostname: "**.discordapp.com",
},
],
},
experimental: {
/* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint
If this can be resolved, we should re-enable this.
+1
@@ -12,6 +12,7 @@
"build-storybook": "build-storybook",
"cypress": "cypress open",
"cypress:run": "cypress run",
"cypress:run:contract": "cypress run --config-file ./cypress.config.contract.js",
"cypress:image-baseline": "cypress-image-diff -u",
"fix:lint": "eslint --fix src/ --ext .js,.jsx,.ts,.tsx",
"fix:format": "prettier --write ./src",
@@ -1,6 +1,6 @@
import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import { useRouter } from "next/router";
import { FiLayout, FiSun } from "react-icons/fi";
import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi";
import { colors } from "styles/Theme/colors";
export function SideMenu() {
@@ -13,6 +13,12 @@ export function SideMenu() {
desc: "Dashboard Home",
icon: FiLayout,
},
{
label: "Messages",
pathname: "/messages",
desc: "Messages Dashboard",
icon: FiMessageSquare,
},
// {
// label: "Leaderboard",
// pathname: "#",
@@ -2,11 +2,17 @@ import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } fro
import Link from "next/link";
const crTasks = [
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
type: "create",
pathname: "/create/initial_prompt",
},
{
label: "Reply as User",
desc: "Chat with Open Assistant and help improve its responses as you interact with it.",
type: "create",
pathname: "/create/assistant_reply",
pathname: "/create/user_reply",
},
{
label: "Reply as Assistant",
+10 -5
@@ -15,6 +15,7 @@ import {
SliderThumb,
SliderTrack,
Spacer,
Tooltip,
useBoolean,
useId,
} from "@chakra-ui/react";
@@ -69,11 +70,15 @@ export const FlaggableElement = (props) => {
>
<Grid templateColumns="1fr min-content" gap={2}>
<PopoverAnchor>{props.children}</PopoverAnchor>
<PopoverTrigger>
<Button h="full">
<FlagIcon className="w-4 text-gray-400 group-hover:text-gray-500" aria-hidden="true" />
</Button>
</PopoverTrigger>
<Tooltip hasArrow label="Report" bg="red.600">
<div>
<PopoverTrigger>
<Button h="full">
<FlagIcon className="w-4 text-gray-400 group-hover:text-gray-500" aria-hidden="true" />
</Button>
</PopoverTrigger>
</div>
</Tooltip>
</Grid>
<PopoverContent width="fit-content">
+1 -1
@@ -42,7 +42,7 @@ export function UserMenu() {
className="flex items-center gap-4 p-1 lg:pr-6 rounded-full transition-colors duration-300"
>
<Image
src="/images/temp-avatars/av5.jpg"
src={session.user.image || "/images/temp-avatars/av1.jpg"}
alt="Profile Picture"
width="36"
height="36"
+7
@@ -25,4 +25,11 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
</div>
);
export const getDashboardLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
{page}
</div>
);
export const noLayout = (page: React.ReactElement) => page;
@@ -0,0 +1,12 @@
import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react";
import { MessageTableEntry } from "./MessageTableEntry";
export function MessageTable({ messages }) {
return (
<Stack divider={<StackDivider />} spacing="4">
{messages.map((item, idx) => (
<MessageTableEntry item={item} idx={idx} key={item.id} />
))}
</Stack>
);
}
@@ -0,0 +1,24 @@
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { boolean } from "boolean";
import NextLink from "next/link";
import { FlaggableElement } from "../FlaggableElement";
export function MessageTableEntry({ item, idx }) {
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
return (
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
<HStack>
<Avatar
name={`${boolean(item.is_assistant) ? "Assistant" : "User"}`}
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
/>
<LinkBox className={`p-4 rounded-md text-white whitespace-pre-wrap ${bgColor} w-full`}>
<NextLink href={`/messages/${item.id}`} passHref>
{item.text}
</NextLink>
</LinkBox>
</HStack>
</FlaggableElement>
);
}
+64
@@ -0,0 +1,64 @@
import { JWT } from "next-auth/jwt";
class OasstError {
message: string;
errorCode: number;
httpStatusCode: number;
constructor(message: string, errorCode: number, httpStatusCode: number) {
this.message = message;
this.errorCode = errorCode;
this.httpStatusCode = httpStatusCode;
}
}
export default class OasstApiClient {
constructor(private readonly oasstApiUrl: string, private readonly oasstApiKey: string) {}
private async post(path: string, body: any): Promise<any> {
const resp = await fetch(`${this.oasstApiUrl}${path}`, {
method: "POST",
headers: {
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
},
body: JSON.stringify(body),
});
if (resp.status == 204) {
return null;
}
if (resp.status >= 300) {
const errorText = await resp.text();
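// Parse first and throw afterwards; a throw inside the try would be caught
// by the catch below and replaced with the generic error.
let error: any;
try {
error = JSON.parse(errorText);
} catch (e) {
// The body is not JSON: fall back to a generic error carrying the raw text.
throw new OasstError(errorText, 0, resp.status);
}
throw new OasstError(error.message, error.error_code, resp.status);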
}
return await resp.json();
}
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
// This is a raw Json type, so we can't use it to strongly type the task.
async fetchTask(taskType: string, userToken: JWT): Promise<any> {
return this.post("/api/v1/tasks/", {
type: taskType,
user: {
id: userToken.sub,
display_name: userToken.name || userToken.email,
auth_method: "local",
},
});
}
async ackTask(taskId: string, messageId: string): Promise<void> {
return this.post(`/api/v1/tasks/${taskId}/ack`, {
message_id: messageId,
});
}
}
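One detail worth watching in a `post` helper like the one above: a `throw` placed inside a `try` block is caught by that same block's `catch`. The following is a minimal sketch, not the project's implementation, of a way to keep JSON parsing and error construction separate. The names `ApiError` and `toApiError` are hypothetical and exist only for this illustration; the `message` and `error_code` fields are assumed to match the backend's error body as used in the diff.

```typescript
// Hypothetical stand-in for the OasstError class in the diff.
class ApiError extends Error {
  constructor(message: string, readonly errorCode: number, readonly httpStatusCode: number) {
    super(message);
    this.name = "ApiError";
  }
}

// Build the error for a failed response without throwing inside the try block.
function toApiError(bodyText: string, status: number): ApiError {
  let parsed: any;
  try {
    parsed = JSON.parse(bodyText);
  } catch {
    // Not JSON: keep the raw text and use 0 as a generic error code.
    return new ApiError(bodyText, 0, status);
  }
  if (parsed === null || typeof parsed !== "object") {
    // Valid JSON but not an object (for example a bare string or number).
    return new ApiError(bodyText, 0, status);
  }
  return new ApiError(parsed.message, parsed.error_code, status);
}

// Usage: a structured backend error and a plain-text fallback.
const structured = toApiError('{"message":"Task not found","error_code":1000}', 404);
const fallback = toApiError("Internal Server Error", 500);
console.log(structured.errorCode, fallback.errorCode);
```

Returning the error from a helper and throwing it at the call site keeps the structured and generic paths distinct, which is also what the TODO(#354) tests would need to tell apart.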
+18 -16
@@ -36,22 +36,24 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<form onSubmit={updateUser}>
<InputGroup>
<Input
onChange={(e) => setUsername(e.target.value)}
placeholder="Edit Username"
type="text"
value={username}
></Input>
<Button disabled={!username} type="submit" value="Change">
Submit
</Button>
</InputGroup>
</form>
</main>
<div className="oa-basic-theme">
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<form onSubmit={updateUser}>
<InputGroup>
<Input
onChange={(e) => setUsername(e.target.value)}
placeholder="Edit Username"
type="text"
value={username}
></Input>
<Button disabled={!username} type="submit" value="Change">
Submit
</Button>
</InputGroup>
</form>
</main>
</div>
</>
);
}
+9 -7
@@ -19,13 +19,15 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<Button>
<Link href="/account/edit">Edit Username</Link>
</Button>
<p>{session.user.email}</p>
</main>
<div className="oa-basic-theme">
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<Button>
<Link href="/account/edit">Edit Username</Link>
</Button>
<p>{session.user.email}</p>
</main>
</div>
</>
);
}
+24
@@ -0,0 +1,24 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Respond with 401 and no body if the user isn't signed in.
if (!token) {
res.status(401).end();
return;
}
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
},
});
const messages = await messagesRes.json();
// Send received messages to the client.
res.status(200).json(messages);
};
export default handler;
+29
@@ -0,0 +1,29 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Respond with 401 and no body if the user isn't signed in.
if (!token) {
res.status(401).end();
return;
}
//TODO: add params if needed
const params = new URLSearchParams({
username: token.sub,
});
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
},
});
const messages = await messagesRes.json();
// Send received messages to the client.
res.status(200).json(messages);
};
export default handler;
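The handler above interpolates a `URLSearchParams` object directly into the request URL. As a small sketch of what that produces (the helper name `buildMessagesUrl` and the sample base URL are hypothetical, used only for illustration), interpolation calls `toString()`, which form-encodes the values:

```typescript
// Hypothetical helper mirroring how the handler builds its backend URL.
function buildMessagesUrl(baseUrl: string, username: string): string {
  const params = new URLSearchParams({ username });
  // Template interpolation calls params.toString(), which encodes the value.
  return `${baseUrl}/api/v1/messages?${params}`;
}

// Spaces become "+" and "&" becomes "%26", so values cannot break the query string.
console.log(buildMessagesUrl("http://localhost:8080", "user 1"));
console.log(buildMessagesUrl("http://localhost:8080", "a&b"));
```

Building the query this way avoids manual string concatenation, so a user identifier containing reserved characters does not alter the request.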
+5 -28
@@ -1,4 +1,5 @@
import { getToken } from "next-auth/jwt";
import OasstApiClient from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
/**
@@ -20,25 +21,10 @@ const handler = async (req, res) => {
return;
}
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
// Fetch the new task.
//
// This needs to be refactored into an easier to use library.
const taskRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
type: task_type,
user: {
id: token.sub,
display_name: token.name || token.email,
auth_method: "local",
},
}),
});
const task = await taskRes.json();
const task = await oasstApiClient.fetchTask(task_type, token);
// Store the task and link it to the user.
const registeredTask = await prisma.registeredTask.create({
@@ -53,16 +39,7 @@ const handler = async (req, res) => {
});
// Update the backend with our Task ID
await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
message_id: registeredTask.id,
}),
});
await oasstApiClient.ackTask(task.id, registeredTask.id);
// Send the results to the client.
res.status(200).json(registeredTask);
+1
@@ -36,6 +36,7 @@ const handler = async (req, res) => {
// Send the interaction to the Task Backend. This automatically fetches the
// next task in the sequence (or the done task).
// TODO(#353): Move this into OasstApiClient.
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/interaction`, {
method: "POST",
headers: {
+7 -1
@@ -1,6 +1,6 @@
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
@@ -21,6 +21,12 @@ const AssistantReply = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
+7 -1
@@ -1,6 +1,6 @@
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
@@ -27,6 +27,12 @@ const InitialPrompt = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const submitResponse = (task: { id: string }) => {
const text = inputRef.current.value.trim();
trigger({
+7 -1
@@ -1,6 +1,6 @@
import { Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
@@ -21,6 +21,12 @@ const UserReply = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
+3 -7
@@ -1,6 +1,7 @@
import { Box, useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { Header } from "src/components/Header";
import { getDashboardLayout } from "src/components/Layout";
import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard";
import { colors } from "styles/Theme/colors";
@@ -27,11 +28,6 @@ const Dashboard = () => {
);
};
Dashboard.getLayout = (page) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
{page}
</div>
);
Dashboard.getLayout = (page) => getDashboardLayout(page);
export default Dashboard;
@@ -1,6 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
@@ -26,6 +26,12 @@ const RankAssistantReplies = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
@@ -1,6 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
@@ -32,6 +32,12 @@ const RankInitialPrompts = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const submitResponse = (task) => {
trigger({
id: task.id,
@@ -1,6 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
@@ -26,6 +26,12 @@ const RankUserReplies = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
+83
@@ -0,0 +1,83 @@
import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import useSWRImmutable from "swr/immutable";
import fetcher from "src/lib/fetcher";
import { SideMenu } from "src/components/Dashboard";
import { MessageTable } from "src/components/Messages/MessageTable";
import { getDashboardLayout } from "src/components/Layout";
import { colors } from "styles/Theme/colors";
const MessagesDashboard = () => {
const bgColor = useColorModeValue(colors.light.bg, colors.dark.bg);
const boxBgColor = useColorModeValue("white", "gray.700");
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
const [messages, setMessages] = useState([]);
const [userMessages, setUserMessages] = useState([]);
const { isLoading: isLoadingAll } = useSWRImmutable("/api/messages", fetcher, {
onSuccess: (data) => {
setMessages(data);
},
});
const { isLoading: isLoadingUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
onSuccess: (data) => {
setUserMessages(data);
},
});
return (
<>
<Head>
<title>Messages - Open Assistant</title>
<meta name="description" content="Chat with Open Assistant and provide feedback." />
</Head>
<Box backgroundColor={bgColor} className="sm:overflow-hidden">
<Box className="sm:flex h-full gap-6">
<Box className="p-6 sm:pr-0">
<SideMenu />
</Box>
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
<SimpleGrid columns={[1, 1, 1, 2]} gap={4}>
<Box>
<Text className="text-2xl font-bold" pb="4">
Most recent messages
</Text>
<Box
backgroundColor={boxBgColor}
boxShadow="base"
dropShadow={boxAccentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
</Box>
</Box>
<Box>
<Text className="text-2xl font-bold" pb="4">
Your most recent messages
</Text>
<Box
backgroundColor={boxBgColor}
boxShadow="base"
dropShadow={boxAccentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
</Box>
</Box>
</SimpleGrid>
</Box>
</Box>
</Box>
</>
);
};
MessagesDashboard.getLayout = (page) => getDashboardLayout(page);
export default MessagesDashboard;
+1 -1
@@ -14,7 +14,7 @@ const PrivacyPolicy = () => {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main>
<main className="oa-basic-theme">
<Container>
<Heading as="h1" size="3xl">
Privacy Policy
+1 -1
@@ -14,7 +14,7 @@ const TermsOfService = () => {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main>
<main className="oa-basic-theme">
<Container>
<Heading as="h1" size="3xl">
Terms Of Service