diff --git a/.github/workflows/test-api-contract.yaml b/.github/workflows/test-api-contract.yaml
index 3707f4de..4ca36da0 100644
--- a/.github/workflows/test-api-contract.yaml
+++ b/.github/workflows/test-api-contract.yaml
@@ -15,6 +15,9 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.10"
+ - uses: actions/setup-node@v3
+ with:
+ node-version: 16
- run: cd oasst-shared && pip install -e .
@@ -22,9 +25,14 @@ jobs:
- run: cd backend && pip install -r requirements.txt
+ - run: cd website && npm install
+
- run: ./scripts/backend-development/start-mock-server.sh
- - name: Run contract tests
+ - name: Run Python OasstApiClient contract tests
run: ./scripts/oasst-shared-development/test.sh
+ - name: Run JavaScript OasstApiClient contract tests
+ run: ./scripts/frontend-development/run-contract-test.sh
+
- run: ./scripts/backend-development/stop-mock-server.sh
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..608afe25
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,110 @@
+# I’m in! Now what?
+
+[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
+this is for work coordination.
+
+[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
+a dedicated channel and is more public.
+
+[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
+dedicated, but not as active, channel.
+
+[Visit the Notion](https://ykilcher.com/open-assistant)
+
+### Taking on Tasks
+
+We have a growing task list
+[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
+that appeals to you and make a comment that you'd like to work on it. Include in
+your comment a brief description of how you'll solve the problem and if there
+are any open questions you want to discuss. Once a project coordinator has
+assigned the issue to you, start working on it.
+
+If the issue is currently unclear but you are interested, please post in Discord
+and someone can help clarify the issue with more detail.
+
+**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
+the system architecture, and other documentation.
+
+### Submitting Work
+
+We're all working on different parts of Open Assistant together. To make
+contributions smoothly we recommend the following:
+
+1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
+ and clone it to your local machine. (Read more
+ [About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
+1. Before working on any changes, try to
+ [sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
+ to keep it up-to-date with the upstream repository.
+1. Work on a small focused change that only touches on a few files.
+1. Run `pre-commit` and make sure all files have formatting fixed. This
+ simplifies life for reviewers.
+1. Package up a small bit of work that solves part of the problem
+ [into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
+ and
+ [send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
+1. If you're lucky, we can merge your change into `main` without any problems.
+ If there's changes to files you're working on, resolve them by:
+1. First try rebase as suggested
+ [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
+1. If rebase feels too painful, merge as suggested
+ [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
+1. Once you've resolved any conflicts, finish the review and merge into `main`.
+1. Merge in your change and move onto a new issue or the second step of your
+ current issue.
+
+Additionally, if someone is working on an issue that interests you, ask if they
+need help on it or would like suggestions on how to approach the issue. If so,
+share wildly. If they seem to have a good handle on it, let them work on their
+solution until a challenge comes up.
+
+### When does a review finish
+
+A review finishes when all blocking comments are addressed and at least one
+owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
+comments either by making the request change, explaining why it's not being
+addressed now, or filing an issue to handle it later.
+
+## Developer Setup
+
+Work is organized in the
+[project board](https://github.com/orgs/LAION-AI/projects/3).
+
+**Anything that is in the `Todo` column and not assigned, is up for grabs.
+Meaning we'd be happy for anyone to do these tasks.**
+
+If you want to work on something, assign yourself to it or write a comment that
+you want to work on it and what you plan to do.
+
+- To get started with development, if you want to work on the backend, have a
+ look at `scripts/backend-development/README.md`.
+- If you want to work on any frontend, have a look at
+ `scripts/frontend-development/README.md` to make a backend available.
+
+There is also a minimal implementation of a frontend in the `text-frontend`
+folder.
+
+We are using Python 3.10 for the backend.
+
+Check out the
+[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
+
+### Website
+
+The website is built using Next.js and is in the `website` folder.
+
+### Pre-commit
+
+Install `pre-commit` and run `pre-commit install` to install the pre-commit
+hooks.
+
+In case you haven't done this, have already committed, and CI is failing, you
+can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
+
+### Deployment
+
+Upon making a release on GitHub, all docker images are automatically built and
+pushed to ghcr.io. The docker images are tagged with the release version, and
+the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
+automatically deploy the built release to the dev machine.
diff --git a/README.md b/README.md
index b619c931..d612940a 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,24 @@
-# Open-Assistant
+
+ Open-Assistant
+
+
-Open Assistant is a project meant to give everyone access to a great chat based
-large language model.
+# Table of Contents
+
+- [What is Open Assistant?](#what-is-open-assistant)
+- [Do you want to try it out?](#do-you-want-to-try-it-out)
+- [The Plan](#the-plan)
+- [The Vision](#the-vision)
+- [How can you help?](#how-can-you-help)
+- [I’m in! How do I contribute?](CONTRIBUTING.md)
+
+---
+
+## What is Open Assistant?
+
+
+ Open Assistant is a project meant to give everyone access to a great chat based large language model.
+
We believe that by doing this we will create a revolution in innovation in
language. In the same way that stable-diffusion helped the world make art and
@@ -14,7 +31,7 @@ If you are interested in taking a look at the current state of the project, you
can set up an entire stack needed to run **Open-Assistant**, including the
website, backend, and associated dependent services.
-To start the demo, run this in the root directory of the repository:
+##### To start the demo, run this in the root directory of the repository:
```sh
docker compose up --build
@@ -23,22 +40,21 @@ docker compose up --build
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
interact with the website.
-**Note:** When logging in via email, navigate to `http://localhost:1080` to get
-the magic email login link.
+> **Note:** When logging in via email, navigate to `http://localhost:1080` to
+> get the magic email login link.
-**Note:** If you would like to run this in a standardized development
-environment (a
-["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
-using
-[vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
-or in a web browser using
-[GitHub Codespaces](https://github.com/features/codespaces), you can use the
-provided [`.devcontainer`](.devcontainer/) folder.
+> **Note:** If you would like to run this in a standardized development
+> environment (a
+> ["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
+> using
+> [vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
+> or in a web browser using
+> [GitHub Codespaces](https://github.com/features/codespaces), you can use the
+> provided [`.devcontainer`](.devcontainer/) folder.
## The Plan
-We want to get to an initial MVP as fast as possible, by following the 3-steps
-outlined in the InstructGPT paper.
+##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
1. Collect high-quality human generated Instruction-Fulfillment samples
(prompt + response), goal >50k. We design a crowdsourced process to collect
@@ -80,113 +96,4 @@ All open source projects begin with people like you. Open source is the belief
that if we collaborate we can together gift our knowledge and technology to the
world for the benefit of humanity.
-## I’m in! Now what?
-
-[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
-this is for work coordination.
-
-[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
-a dedicated channel and is more public.
-
-[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
-dedicated, but not as active, channel.
-
-[Visit the Notion](https://ykilcher.com/open-assistant)
-
-### Taking on Tasks
-
-We have a growing task list
-[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
-that appeals to you and make a comment that you'd like to work on it. Include in
-your comment a brief description of how you'll solve the problem and if there
-are any open questions you want to discuss. Once a project coordinator has
-assigned the issue to you, start working on it.
-
-If the issue is currently unclear but you are interested, please post in Discord
-and someone can help clarify the issue with more detail.
-
-**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
-the system architecture, and other documentation.
-
-### Submitting Work
-
-We're all working on different parts of Open Assistant together. To make
-contributions smoothly we recommend the following:
-
-1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
- and clone it to your local machine. (Read more
- [About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
-1. Before working on any changes, try to
- [sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
- to keep it up-to-date with the upstream repository.
-1. Work on a small focused change that only touches on a few files.
-1. Run `pre-commit` and make sure all files have formatting fixed. This
- simplifies life for reviewers.
-1. Package up a small bit of work that solves part of the problem
- [into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
- and
- [send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
-1. If you're lucky, we can merge your change into `main` without any problems.
- If there's changes to files you're working on, resolve them by:
-1. First try rebase as suggested
- [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
-1. If rebase feels too painful, merge as suggested
- [in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
-1. Once you've resolved any conflicts, finish the review and merge into `main`.
-1. Merge in your change and move onto a new issue or the second step of your
- current issue.
-
-Additionally, if someone is working on an issue that interests you, ask if they
-need help on it or would like suggestions on how to approach the issue. If so,
-share wildly. If they seem to have a good handle on it, let them work on their
-solution until a challenge comes up.
-
-### When does a review finish
-
-A review finishes when all blocking comments are addressed and at least one
-owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
-comments either by making the request change, explaining why it's not being
-addressed now, or filing an issue to handle it later.
-
-## Developer Setup
-
-Work is organized in the
-[project board](https://github.com/orgs/LAION-AI/projects/3).
-
-**Anything that is in the `Todo` column and not assigned, is up for grabs.
-Meaning we'd be happy for anyone to do these tasks.**
-
-If you want to work on something, assign yourself to it or write a comment that
-you want to work on it and what you plan to do.
-
-- To get started with development, if you want to work on the backend, have a
- look at `scripts/backend-development/README.md`.
-- If you want to work on any frontend, have a look at
- `scripts/frontend-development/README.md` to make a backend available.
-
-There is also a minimal implementation of a frontend in the `text-frontend`
-folder.
-
-We are using Python 3.10 for the backend.
-
-Check out the
-[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
-
-### Website
-
-The website is built using Next.js and is in the `website` folder.
-
-### Pre-commit
-
-Install `pre-commit` and run `pre-commit install` to install the pre-commit
-hooks.
-
-In case you haven't done this, have already committed, and CI is failing, you
-can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
-
-### Deployment
-
-Upon making a release on GitHub, all docker images are automatically built and
-pushed to ghcr.io. The docker images are tagged with the release version, and
-the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
-automatically deploy the built release to the dev machine.
+Check out our [contributing guide](CONTRIBUTING.md) to get started.
diff --git a/ansible/dev.yaml b/ansible/dev.yaml
index c9195966..ca2a11d9 100644
--- a/ansible/dev.yaml
+++ b/ansible/dev.yaml
@@ -82,6 +82,7 @@
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
+ RATE_LIMIT: "false"
ports:
- 8080:8080
diff --git a/assets/logo_crop.png b/assets/logo_crop.png
new file mode 100644
index 00000000..20630d6b
Binary files /dev/null and b/assets/logo_crop.png differ
diff --git a/backend/README.md b/backend/README.md
index 45d16d68..863090f7 100644
--- a/backend/README.md
+++ b/backend/README.md
@@ -11,10 +11,12 @@ Example contents of a `.env` file for the backend:
```
DATABASE_URI="postgresql://:@/"
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"]
-
+REDIS_HOST=localhost
+REDIS_PORT=6379
```
## Running the REST Server locally for development
Have a look into the main `README.md` file for more information on how to set up
-the backend for development.
+the backend for development. Use the scripts within the
+scripts/backend-development folder to run the BE API locally.
diff --git a/copilot/api/manifest.yml b/copilot/api/manifest.yml
index b9262b51..b6ff6cf7 100644
--- a/copilot/api/manifest.yml
+++ b/copilot/api/manifest.yml
@@ -36,3 +36,4 @@ environments:
secrets:
# Note: URI, not URL.
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
+ REDIS_HOST: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/REDIS_HOST
diff --git a/model/reward/instructor/configs/rankgen-t5-base-fp16.yml b/model/reward/instructor/configs/rankgen-t5-base-fp16.yml
new file mode 100644
index 00000000..c6f2a5e0
--- /dev/null
+++ b/model/reward/instructor/configs/rankgen-t5-base-fp16.yml
@@ -0,0 +1,16 @@
+model_name: kalpeshk2011/rankgen-t5-base-all
+tokenizer_name: google/t5-v1_1-base
+learning_rate: 6e-6
+gradient_checkpointing: false
+fp16: true
+gradient_accumulation_steps: 16
+per_device_train_batch_size: 2
+warmup_steps: 600
+freeze_layer: 20
+eval_steps: 200
+save_steps: 500
+max_length: 400
+num_train_epochs: 2
+datasets:
+ - webgpt
+ - hfsummary
diff --git a/model/reward/instructor/configs/rankgen-t5-base.yml b/model/reward/instructor/configs/rankgen-t5-base.yml
new file mode 100644
index 00000000..bcb4d613
--- /dev/null
+++ b/model/reward/instructor/configs/rankgen-t5-base.yml
@@ -0,0 +1,19 @@
+model_name: kalpeshk2011/rankgen-t5-base-all
+# model_name: kalpeshk2011/rankgen-t5-xl-all
+# model_name: kalpeshk2011/rankgen-t5-xl-pg19
+# model_name: kalpeshk2011/rankgen-t5-large-all
+tokenizer_name: google/t5-v1_1-base
+learning_rate: 6e-6
+gradient_checkpointing: false
+fp16: false
+gradient_accumulation_steps: 16
+per_device_train_batch_size: 2
+warmup_steps: 600
+freeze_layer: 20
+eval_steps: 200
+save_steps: 500
+max_length: 400
+num_train_epochs: 2
+datasets:
+ - webgpt
+ - hfsummary
diff --git a/model/reward/instructor/models.py b/model/reward/instructor/models.py
new file mode 100644
index 00000000..c1891ed2
--- /dev/null
+++ b/model/reward/instructor/models.py
@@ -0,0 +1,27 @@
+import torch
+from transformers import AutoModel
+
+
+class RankGenModel(torch.nn.Module):
+ def __init__(self, model_name):
+ super().__init__()
+ self.rankgen_hf_hub = model_name
+ assert model_name in [
+ "kalpeshk2011/rankgen-t5-xl-all",
+ "kalpeshk2011/rankgen-t5-xl-pg19",
+ "kalpeshk2011/rankgen-t5-base-all",
+ "kalpeshk2011/rankgen-t5-large-all",
+ ]
+ self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)
+
+ def forward(self, prefixes, suffixes):
+ # print(list(self.model.parameters()))
+ # raise Exception("stop")
+ embedded_prefixes = self.model(**prefixes)
+ embedded_suffixes = self.model(**suffixes)
+ # take dot product of each row independently
+ dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
+
+ # print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}")
+ # raise Exception("stop")
+ return dot_products
diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py
index a5c4b4fd..5e7da948 100644
--- a/model/reward/instructor/rank_datasets.py
+++ b/model/reward/instructor/rank_datasets.py
@@ -22,11 +22,41 @@ from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
+import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
+@dataclass
+class RankGenCollator:
+ tokenizer: PreTrainedTokenizerBase
+ padding: Union[bool, str, PaddingStrategy] = True
+ max_length: Optional[int] = None
+ max_examples: Optional[int] = None
+
+ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
+ prefixes = []
+ better_answers = []
+ worse_answers = []
+ for question, pairs in batch:
+ for (pos, neg) in pairs:
+ prefixes.append("pre " + question)
+ better_answers.append("suffi " + pos)
+ worse_answers.append("suffi " + neg)
+
+ tokenized_prefixes = self.tokenizer(
+ prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
+ )
+ tokenized_pos = self.tokenizer(
+ better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
+ )
+ tokenized_neg = self.tokenizer(
+ worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
+ )
+ return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg}
+
+
@dataclass
class DataCollatorForPairRank:
"""
diff --git a/model/reward/instructor/requirements.txt b/model/reward/instructor/requirements.txt
index e225a2ca..ca3935e4 100644
--- a/model/reward/instructor/requirements.txt
+++ b/model/reward/instructor/requirements.txt
@@ -1,6 +1,7 @@
datasets==2.8.0
evaluate==0.4.0
scikit-learn==1.2.0
-torch==1.12.1+cu116
+sentencepiece==0.1.97
+torch>=1.12.1
transformers==4.25.1
wandb==0.13.7
diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py
index b7eb8731..f9266d70 100644
--- a/model/reward/instructor/trainer.py
+++ b/model/reward/instructor/trainer.py
@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
-from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
+from models import RankGenModel
+from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
@@ -46,14 +47,16 @@ class RankLoss(nn.Module):
self.log_sigmoid = nn.LogSigmoid()
def forward(self, pos, neg):
- return -self.log_sigmoid(pos - neg + self.eps).mean()
+ loss = -self.log_sigmoid(pos - neg + self.eps).mean()
+ return loss
class RankTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
- args: TrainingArguments = None,
+ model_name: str = None,
+ args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
@@ -79,15 +82,25 @@ class RankTrainer(Trainer):
)
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = args.loss_function
+ self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
# forward pass
- outputs = model(**inputs)
- logits = outputs.get("logits").view(-1, 2)
- if self.loss_function == "rank":
- loss = self.loss_fct(logits[:, 0], logits[:, 1])
+ if "rankgen" in self.model_name:
+ positive_outputs = model(inputs["prefix"], inputs["positive"])
+ negative_outputs = model(inputs["prefix"], inputs["negative"])
+ if self.loss_function == "rank":
+ loss = self.loss_fct(positive_outputs, negative_outputs)
+ else:
+ raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
+ outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
else:
- loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
+ outputs = model(**inputs)
+ logits = outputs.get("logits").view(-1, 2)
+ if self.loss_function == "rank":
+ loss = self.loss_fct(logits[:, 0], logits[:, 1])
+ else:
+ loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
return (loss, outputs) if return_outputs else loss
@@ -109,24 +122,37 @@ class RankTrainer(Trainer):
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ with torch.inference_mode():
+ if "rankgen" in self.model_name:
+ inputs = self._prepare_inputs(inputs)
+ positive_outputs = model(inputs["prefix"], inputs["positive"])
+ negative_outputs = model(inputs["prefix"], inputs["negative"])
+ if self.loss_function == "rank":
+ loss = self.loss_fct(positive_outputs, negative_outputs)
+ else:
+ raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
+ outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
+ return (loss, outputs, None)
+ else:
+ # compute loss on predict data
+ loss, logits = self._compute_loss(model, inputs)
- with torch.no_grad():
- # compute loss on predict data
- loss, logits = self._compute_loss(model, inputs)
+ loss = loss.mean().detach()
+ labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
+ if self.args.prediction_loss_only:
+ return (loss, None, None)
- loss = loss.mean().detach()
- labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
- if self.args.prediction_loss_only:
- return (loss, None, None)
-
- return (loss, logits, labels)
+ return (loss, logits, labels)
if __name__ == "__main__":
training_conf = argument_parsing(parser)
model_name = training_conf["model_name"]
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
+ if "rankgen-t5" in model_name:
+ model = RankGenModel(model_name)
+ else:
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "freeze_layer" in training_conf:
num_layer = training_conf["freeze_layer"]
model = freeze_top_n_layers(model, num_layer)
@@ -134,7 +160,6 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
- tokenizer = get_tokenizer(model_name)
args = CustomTrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
@@ -142,7 +167,7 @@ if __name__ == "__main__":
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
- fp16=True,
+ fp16=training_conf["fp16"],
gradient_checkpointing=training_conf["gradient_checkpointing"],
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
@@ -154,7 +179,7 @@ if __name__ == "__main__":
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
- report_to="wandb",
+ report_to="local",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
@@ -169,17 +194,23 @@ if __name__ == "__main__":
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
- collate_fn = DataCollatorForPairRank(
- tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
- )
+
+ tokenizer = get_tokenizer(training_conf["tokenizer_name"])
+
+ if "rankgen" in model_name:
+ collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
+ else:
+ collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
trainer = RankTrainer(
- model,
- args,
+ model=model,
+ model_name=model_name,
+ args=args,
train_dataset=train,
eval_dataset=eval,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
+ # trainer.evaluate()
trainer.train()
diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py
index 6c777dea..fe52c2ef 100644
--- a/model/reward/instructor/utils.py
+++ b/model/reward/instructor/utils.py
@@ -3,7 +3,7 @@ import re
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, T5Tokenizer
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
@@ -25,7 +25,10 @@ def webgpt_return_format(row):
def get_tokenizer(tokenizer_name):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ if "t5" in tokenizer_name: # rankgen
+ tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "galactica" in tokenizer_name:
tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""})
@@ -67,6 +70,10 @@ def freeze_top_n_layers(model, target_layers):
def argument_parsing(parser):
+ args = parser.parse_args()
+ with open(args.config, "r", encoding="utf-8") as f:
+ training_conf = yaml.safe_load(f.read())
+
default_params = {
"num_train_epochs": 4,
"learning_rate": 3e-5,
@@ -78,10 +85,9 @@ def argument_parsing(parser):
"gradient_accumulation_steps": 8,
"gradient_checkpointing": False,
"datasets": ["webgpt"],
+ "fp16": True,
+ "tokenizer_name": training_conf["model_name"],
}
- args = parser.parse_args()
- with open(args.config, "r", encoding="utf-8") as f:
- training_conf = yaml.safe_load(f.read())
params = {**default_params, **training_conf}
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md
index e223e1cd..014afa95 100644
--- a/model/supervised_finetuning/README.md
+++ b/model/supervised_finetuning/README.md
@@ -33,6 +33,6 @@ Experimental results in wandb
## TODOS
- decide on a model
-- add special token to declare prompt and reply. Do nto freeze the weights for
- these
- Merge utils etc with reward model
+- Casual Modelling for GPT-JT does not leverage the bidirectional mask for the
+ prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1)
diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml
index f7164002..29086395 100644
--- a/model/supervised_finetuning/configs/config.yaml
+++ b/model/supervised_finetuning/configs/config.yaml
@@ -32,6 +32,17 @@ galactica-125:
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
+gpt-jt:
+ learning_rate: 2e-6
+ model_name: togethercomputer/GPT-JT-6B-v1
+ weight_decay: 0.01
+ max_length: 1024
+ warmup_steps: 600
+ gradient_checkpointing: false
+ gradient_accumulation_steps: 2
+ per_device_train_batch_size: 4
+ per_device_eval_batch_size: 4
+
debug:
eval_steps: 20
eval_size: 100
diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py
index fcab8a56..7e3bdc79 100644
--- a/model/supervised_finetuning/custom_datasets/__init__.py
+++ b/model/supervised_finetuning/custom_datasets/__init__.py
@@ -2,6 +2,8 @@ from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
+QA_SPECIAL_TOKENS = {"Question": "", "Answer": ""}
+
class SquadV2Dataset(Dataset):
def __init__(self, cache_dir, split):
diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py
index 17fe1082..f9e1bb5e 100644
--- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py
+++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py
@@ -6,6 +6,8 @@ import torch
from torch.nn import functional as F
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
+from . import QA_SPECIAL_TOKENS
+
@dataclass
class DialogueDataCollator:
@@ -19,22 +21,21 @@ class DialogueDataCollator:
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
- # TODO add special tokens for question and answer here
- # additional_special_tokens = ['', '']
- prompt_tokens = ["Question: ", "Answer: "]
-
flatten_messages = []
label_masks = []
for messages in features:
assert len(messages) % 2 == 0, "Number of messages must be even"
messages = [
- (prompt_tokens[0] if i % 2 == 0 else "") + x + ((" " + prompt_tokens[1]) if i % 2 == 0 else "")
+ (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
+ + x
+ + (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
for i, x in enumerate(messages)
]
- # Add a way for the model to terminate generation, reinitialize prompter
- messages.append(prompt_tokens[0])
+ # Add a way for the model to terminate generation
+ # When we predict the start of a new expected question, we want to be able to stop generation
+ messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_messages.append(
self.tokenizer(
@@ -47,8 +48,10 @@ class DialogueDataCollator:
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
- # TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John.
- # MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3
+ # Label mask is true when predicting a token that is part of the answer, false otherwise.
+ # TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
+ # MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
+ # LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0
# If no result in next, we are predicting the last termination token(s)
message_indices = list(
diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt
new file mode 100644
index 00000000..d579468f
--- /dev/null
+++ b/model/supervised_finetuning/requirements.txt
@@ -0,0 +1,6 @@
+datasets==2.8.0
+numpy==1.23.0
+PyYAML==6.0
+scikit_learn==1.2.0
+torch==1.13.1
+transformers==4.25.1
diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py
index b44890df..dc7b5934 100644
--- a/model/supervised_finetuning/trainer.py
+++ b/model/supervised_finetuning/trainer.py
@@ -67,6 +67,8 @@ class SFTTrainer(Trainer):
optimizers,
preprocess_logits_for_metrics,
)
+
+ # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(args.loss_function)
def fetch_scheduler(self):
@@ -112,7 +114,7 @@ class SFTTrainer(Trainer):
with torch.no_grad():
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
- labels[~labels_mask] = -1
+ labels[~labels_mask] = -100 # padding_index
loss = loss.mean().detach()
@@ -159,8 +161,8 @@ def argument_parsing(notebook=False, notebook_args=None):
if __name__ == "__main__":
training_conf = argument_parsing()
- model = get_model(training_conf)
tokenizer = get_tokenizer(training_conf)
+ model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py
index 4a451bed..a31f74d3 100644
--- a/model/supervised_finetuning/utils.py
+++ b/model/supervised_finetuning/utils.py
@@ -1,14 +1,14 @@
from pathlib import Path
import yaml
-from custom_datasets import get_one_dataset
+from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from losses import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
from transformers import AutoModelForCausalLM, AutoTokenizer
-SUPPORTED_MODELS = ["galactica"]
+SUPPORTED_MODELS = ["galactica", "GPT-JT"] # deprecated ..
def get_tokenizer(conf):
@@ -17,10 +17,19 @@ def get_tokenizer(conf):
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""})
+ additional_special_tokens = (
+ []
+ if "additional_special_tokens" not in tokenizer.special_tokens_map
+ else tokenizer.special_tokens_map["additional_special_tokens"]
+ )
+ additional_special_tokens = list(set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
+
+ tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+
return tokenizer
-def get_model(conf):
+def get_model(conf, tokenizer):
if not any([x in conf.model_name for x in SUPPORTED_MODELS]):
raise ValueError(
f"Model {conf.model_name} not supported. Supported models: {SUPPORTED_MODELS}. "
@@ -29,6 +38,11 @@ def get_model(conf):
model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
+ if len(tokenizer) != model.get_input_embeddings().num_embeddings:
+ assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
+
+ model.resize_token_embeddings(len(tokenizer))
+
if conf.freeze_layer:
model = freeze_top_n_layers(model, conf.freeze_layer)
diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py
index 83375d8f..e035d387 100644
--- a/oasst-shared/oasst_shared/schemas/protocol.py
+++ b/oasst-shared/oasst_shared/schemas/protocol.py
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
import pydantic
from oasst_shared.exceptions import OasstErrorCode
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, conint, conlist, constr
class TaskRequestType(str, enum.Enum):
@@ -203,7 +203,7 @@ class TextReplyToMessage(Interaction):
type: Literal["text_reply_to_message"] = "text_reply_to_message"
message_id: str
user_message_id: str
- text: str
+ text: constr(min_length=1, strip_whitespace=True)
class MessageRating(Interaction):
@@ -211,7 +211,7 @@ class MessageRating(Interaction):
type: Literal["message_rating"] = "message_rating"
message_id: str
- rating: int
+ rating: conint(gt=0)
class MessageRanking(Interaction):
@@ -219,7 +219,7 @@ class MessageRanking(Interaction):
type: Literal["message_ranking"] = "message_ranking"
message_id: str
- ranking: list[int]
+ ranking: conlist(item_type=int, min_items=1)
AnyInteraction = Union[
diff --git a/scripts/frontend-development/run-contract-test.sh b/scripts/frontend-development/run-contract-test.sh
new file mode 100755
index 00000000..6bedc903
--- /dev/null
+++ b/scripts/frontend-development/run-contract-test.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
+
+# switch to website directory
+pushd "$parent_path/../../website"
+
+set -xe
+
+npm run cypress:run:contract
+
+popd
diff --git a/website/cypress.config.contract.js b/website/cypress.config.contract.js
new file mode 100644
index 00000000..f4461158
--- /dev/null
+++ b/website/cypress.config.contract.js
@@ -0,0 +1,9 @@
+import { defineConfig } from "cypress";
+
+export default defineConfig({
+ e2e: {
+ // No baseUrl here, because we don't need it for contract testing
+ baseUrl: null,
+ specPattern: "cypress/contract/*.cy.{ts,js}",
+ },
+});
diff --git a/website/cypress/contract/oasst_api_contract_tests.cy.ts b/website/cypress/contract/oasst_api_contract_tests.cy.ts
new file mode 100644
index 00000000..ff5bb156
--- /dev/null
+++ b/website/cypress/contract/oasst_api_contract_tests.cy.ts
@@ -0,0 +1,29 @@
+import OasstApiClient from "src/lib/oasst_api_client";
+
+describe("Contract test for Oasst API", function () {
+ // Assumes this is running the mock server.
+ const oasstApiClient = new OasstApiClient("http://localhost:8080", "test");
+
+ it("can fetch a task", async () => {
+ expect(
+ await oasstApiClient.fetchTask("random", {
+ sub: "test",
+ name: "test",
+ email: "test",
+ })
+ ).to.be.not.null;
+ });
+
+ it("can ack a task", async () => {
+ const task = await oasstApiClient.fetchTask("random", {
+ sub: "test",
+ name: "test",
+ email: "test",
+ });
+ expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
+ });
+
+ // TODO(#354): Add test for 204
+ // TODO(#354): Add test for parsing >=300, throwing an OasstError
+ // TODO(#354): Add test for parsing >=300, throwing a generic error
+});
diff --git a/website/next.config.js b/website/next.config.js
index 2c37ebe6..28da824f 100644
--- a/website/next.config.js
+++ b/website/next.config.js
@@ -2,6 +2,14 @@
const nextConfig = {
output: "standalone",
reactStrictMode: true,
+ images: {
+ remotePatterns: [
+ {
+ protocol: "https",
+ hostname: "**.discordapp.com",
+ },
+ ],
+ },
experimental: {
/* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint
If this can be resolved, we should re-enable this.
diff --git a/website/package.json b/website/package.json
index c66e10ca..7d3b680e 100644
--- a/website/package.json
+++ b/website/package.json
@@ -12,6 +12,7 @@
"build-storybook": "build-storybook",
"cypress": "cypress open",
"cypress:run": "cypress run",
+ "cypress:run:contract": "cypress run --config-file ./cypress.config.contract.js",
"cypress:image-baseline": "cypress-image-diff -u",
"fix:lint": "eslint --fix src/ --ext .js,.jsx,.ts,.tsx",
"fix:format": "prettier --write ./src",
diff --git a/website/src/components/Dashboard/SideMenu.tsx b/website/src/components/Dashboard/SideMenu.tsx
index 30a45777..499117a2 100644
--- a/website/src/components/Dashboard/SideMenu.tsx
+++ b/website/src/components/Dashboard/SideMenu.tsx
@@ -1,6 +1,6 @@
import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import { useRouter } from "next/router";
-import { FiLayout, FiSun } from "react-icons/fi";
+import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi";
import { colors } from "styles/Theme/colors";
export function SideMenu() {
@@ -13,6 +13,12 @@ export function SideMenu() {
desc: "Dashboard Home",
icon: FiLayout,
},
+ {
+ label: "Messages",
+ pathname: "/messages",
+ desc: "Messages Dashboard",
+ icon: FiMessageSquare,
+ },
// {
// label: "Leaderboard",
// pathname: "#",
diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx
index 6b17a079..50f707a6 100644
--- a/website/src/components/Dashboard/TaskOption.tsx
+++ b/website/src/components/Dashboard/TaskOption.tsx
@@ -2,11 +2,17 @@ import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } fro
import Link from "next/link";
const crTasks = [
+ {
+ label: "Create Initial Prompts",
+ desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
+ type: "create",
+ pathname: "/create/initial_prompt",
+ },
{
label: "Reply as User",
desc: "Chat with Open Assistant and help improve it’s responses as you interact with it.",
type: "create",
- pathname: "/create/assistant_reply",
+ pathname: "/create/user_reply",
},
{
label: "Reply as Assistant",
diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx
index c8fc17ce..bc245bd2 100644
--- a/website/src/components/FlaggableElement.tsx
+++ b/website/src/components/FlaggableElement.tsx
@@ -15,6 +15,7 @@ import {
SliderThumb,
SliderTrack,
Spacer,
+ Tooltip,
useBoolean,
useId,
} from "@chakra-ui/react";
@@ -69,11 +70,15 @@ export const FlaggableElement = (props) => {
>
{props.children}
-
-
-
+
+
+);
+
export const noLayout = (page: React.ReactElement) => page;
diff --git a/website/src/components/Messages/MessageTable.tsx b/website/src/components/Messages/MessageTable.tsx
new file mode 100644
index 00000000..65f6efb1
--- /dev/null
+++ b/website/src/components/Messages/MessageTable.tsx
@@ -0,0 +1,12 @@
+import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react";
+import { MessageTableEntry } from "./MessageTableEntry";
+
+export function MessageTable({ messages }) {
+ return (
+ } spacing="4">
+ {messages.map((item, idx) => (
+
+ ))}
+
+ );
+}
diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx
new file mode 100644
index 00000000..2bc17201
--- /dev/null
+++ b/website/src/components/Messages/MessageTableEntry.tsx
@@ -0,0 +1,24 @@
+import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
+import { boolean } from "boolean";
+import NextLink from "next/link";
+import { FlaggableElement } from "../FlaggableElement";
+
+export function MessageTableEntry({ item, idx }) {
+ const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
+
+ return (
+
+
+
+
+
+ {item.text}
+
+
+
+
+ );
+}
diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts
new file mode 100644
index 00000000..45a0859e
--- /dev/null
+++ b/website/src/lib/oasst_api_client.ts
@@ -0,0 +1,64 @@
+import { JWT } from "next-auth/jwt";
+
+class OasstError {
+ message: string;
+ errorCode: number;
+ httpStatusCode: number;
+
+ constructor(message: string, errorCode: number, httpStatusCode: number) {
+ this.message = message;
+ this.errorCode = errorCode;
+ this.httpStatusCode = httpStatusCode;
+ }
+}
+
+export default class OasstApiClient {
+ constructor(private readonly oasstApiUrl: string, private readonly oasstApiKey: string) {}
+
+ private async post(path: string, body: any): Promise {
+ const resp = await fetch(`${this.oasstApiUrl}${path}`, {
+ method: "POST",
+ headers: {
+ "X-API-Key": this.oasstApiKey,
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify(body),
+ });
+
+ if (resp.status == 204) {
+ return null;
+ }
+
+ if (resp.status >= 300) {
+ const errorText = await resp.text();
+ try {
+ const error = JSON.parse(errorText);
+ throw new OasstError(error.message, error.error_code, resp.status);
+ } catch (e) {
+ throw new OasstError(errorText, 0, resp.status);
+ }
+ }
+
+ return await resp.json();
+ }
+
+ // TODO return a strongly typed Task?
+ // This method is used to store a task in RegisteredTask.task.
+ // This is a raw Json type, so we can't use it to strongly type the task.
+ async fetchTask(taskType: string, userToken: JWT): Promise {
+ return this.post("/api/v1/tasks/", {
+ type: taskType,
+ user: {
+ id: userToken.sub,
+ display_name: userToken.name || userToken.email,
+ auth_method: "local",
+ },
+ });
+ }
+
+ async ackTask(taskId: string, messageId: string): Promise {
+ return this.post(`/api/v1/tasks/${taskId}/ack`, {
+ message_id: messageId,
+ });
+ }
+}
diff --git a/website/src/pages/account/edit.tsx b/website/src/pages/account/edit.tsx
index f695fce7..497e8238 100644
--- a/website/src/pages/account/edit.tsx
+++ b/website/src/pages/account/edit.tsx
@@ -36,22 +36,24 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
-
-
{session.user.name || "No username"}
-
-
+
+
+
{session.user.name || "No username"}
+
+
+
>
);
}
diff --git a/website/src/pages/account/index.tsx b/website/src/pages/account/index.tsx
index 9f8dc4a3..d26fc842 100644
--- a/website/src/pages/account/index.tsx
+++ b/website/src/pages/account/index.tsx
@@ -19,13 +19,15 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
-
-
{session.user.name || "No username"}
-
-
{session.user.email}
-
+
+
+
{session.user.name || "No username"}
+
+
{session.user.email}
+
+
>
);
}
diff --git a/website/src/pages/api/messages/index.tsx b/website/src/pages/api/messages/index.tsx
new file mode 100644
index 00000000..3c8d1c17
--- /dev/null
+++ b/website/src/pages/api/messages/index.tsx
@@ -0,0 +1,24 @@
+import { getToken } from "next-auth/jwt";
+
+const handler = async (req, res) => {
+ const token = await getToken({ req });
+
+ // Return nothing if the user isn't registered.
+ if (!token) {
+ res.status(401).end();
+ return;
+ }
+
+ const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
+ method: "GET",
+ headers: {
+ "X-API-Key": process.env.FASTAPI_KEY,
+ },
+ });
+ const messages = await messagesRes.json();
+
+ // Send recieved messages to the client.
+ res.status(200).json(messages);
+};
+
+export default handler;
diff --git a/website/src/pages/api/messages/user.tsx b/website/src/pages/api/messages/user.tsx
new file mode 100644
index 00000000..e3d22f3c
--- /dev/null
+++ b/website/src/pages/api/messages/user.tsx
@@ -0,0 +1,29 @@
+import { getToken } from "next-auth/jwt";
+
+const handler = async (req, res) => {
+ const token = await getToken({ req });
+
+ // Return nothing if the user isn't registered.
+ if (!token) {
+ res.status(401).end();
+ return;
+ }
+
+ //TODO: add params if needed
+ const params = new URLSearchParams({
+ username: token.sub,
+ });
+
+ const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
+ method: "GET",
+ headers: {
+ "X-API-Key": process.env.FASTAPI_KEY,
+ },
+ });
+ const messages = await messagesRes.json();
+
+ // Send recieved messages to the client.
+ res.status(200).json(messages);
+};
+
+export default handler;
diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts
index 50f0b4e2..bbe31bef 100644
--- a/website/src/pages/api/new_task/[task_type].ts
+++ b/website/src/pages/api/new_task/[task_type].ts
@@ -1,4 +1,5 @@
import { getToken } from "next-auth/jwt";
+import OasstApiClient from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
/**
@@ -20,25 +21,10 @@ const handler = async (req, res) => {
return;
}
+ const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
+
// Fetch the new task.
- //
- // This needs to be refactored into an easier to use library.
- const taskRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/`, {
- method: "POST",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- body: JSON.stringify({
- type: task_type,
- user: {
- id: token.sub,
- display_name: token.name || token.email,
- auth_method: "local",
- },
- }),
- });
- const task = await taskRes.json();
+ const task = await oasstApiClient.fetchTask(task_type, token);
// Store the task and link it to the user..
const registeredTask = await prisma.registeredTask.create({
@@ -53,16 +39,7 @@ const handler = async (req, res) => {
});
// Update the backend with our Task ID
- await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
- method: "POST",
- headers: {
- "X-API-Key": process.env.FASTAPI_KEY,
- "Content-Type": "application/json",
- },
- body: JSON.stringify({
- message_id: registeredTask.id,
- }),
- });
+ await oasstApiClient.ackTask(task.id, registeredTask.id);
// Send the results to the client.
res.status(200).json(registeredTask);
diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts
index 9582040b..2d371354 100644
--- a/website/src/pages/api/update_task.ts
+++ b/website/src/pages/api/update_task.ts
@@ -36,6 +36,7 @@ const handler = async (req, res) => {
// Send the interaction to the Task Backend. This automatically fetches the
// next task in the sequence (or the done task).
+ // TODO(#353): Move this into OasstApiClient.
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/interaction`, {
method: "POST",
headers: {
diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx
index ceac45be..2f7ca748 100644
--- a/website/src/pages/create/assistant_reply.tsx
+++ b/website/src/pages/create/assistant_reply.tsx
@@ -1,6 +1,6 @@
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
-import { useRef, useState } from "react";
+import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
@@ -21,6 +21,12 @@ const AssistantReply = () => {
},
});
+ useEffect(() => {
+ if (tasks.length == 0) {
+ mutate();
+ }
+ }, [tasks]);
+
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx
index 5aecf98c..a4aed9c3 100644
--- a/website/src/pages/create/initial_prompt.tsx
+++ b/website/src/pages/create/initial_prompt.tsx
@@ -1,6 +1,6 @@
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
-import { useRef, useState } from "react";
+import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
@@ -27,6 +27,12 @@ const InitialPrompt = () => {
},
});
+ useEffect(() => {
+ if (tasks.length == 0) {
+ mutate();
+ }
+ }, [tasks]);
+
const submitResponse = (task: { id: string }) => {
const text = inputRef.current.value.trim();
trigger({
diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx
index b9022e86..40409189 100644
--- a/website/src/pages/create/user_reply.tsx
+++ b/website/src/pages/create/user_reply.tsx
@@ -1,6 +1,6 @@
import { Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
-import { useRef, useState } from "react";
+import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
@@ -21,6 +21,12 @@ const UserReply = () => {
},
});
+ useEffect(() => {
+ if (tasks.length == 0) {
+ mutate();
+ }
+ }, [tasks]);
+
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx
index dfc5cb03..8b1f6861 100644
--- a/website/src/pages/dashboard.tsx
+++ b/website/src/pages/dashboard.tsx
@@ -1,6 +1,7 @@
import { Box, useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { Header } from "src/components/Header";
+
+import { getDashboardLayout } from "src/components/Layout";
import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard";
import { colors } from "styles/Theme/colors";
@@ -27,11 +28,6 @@ const Dashboard = () => {
);
};
-Dashboard.getLayout = (page) => (
-
-
- {page}
-
-);
+Dashboard.getLayout = (page) => getDashboardLayout(page);
export default Dashboard;
diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx
index e8558c00..0ed69b09 100644
--- a/website/src/pages/evaluate/rank_assistant_replies.tsx
+++ b/website/src/pages/evaluate/rank_assistant_replies.tsx
@@ -1,6 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useState } from "react";
+import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
@@ -26,6 +26,12 @@ const RankAssistantReplies = () => {
},
});
+ useEffect(() => {
+ if (tasks.length == 0) {
+ mutate();
+ }
+ }, [tasks]);
+
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx
index 48a67e90..a9d590ac 100644
--- a/website/src/pages/evaluate/rank_initial_prompts.tsx
+++ b/website/src/pages/evaluate/rank_initial_prompts.tsx
@@ -1,6 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useState } from "react";
+import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
@@ -32,6 +32,12 @@ const RankInitialPrompts = () => {
},
});
+ useEffect(() => {
+ if (tasks.length == 0) {
+ mutate();
+ }
+ }, [tasks]);
+
const submitResponse = (task) => {
trigger({
id: task.id,
diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx
index 3f806a8a..9a0577cb 100644
--- a/website/src/pages/evaluate/rank_user_replies.tsx
+++ b/website/src/pages/evaluate/rank_user_replies.tsx
@@ -1,6 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useState } from "react";
+import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
@@ -26,6 +26,12 @@ const RankUserReplies = () => {
},
});
+ useEffect(() => {
+ if (tasks.length == 0) {
+ mutate();
+ }
+ }, [tasks]);
+
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx
new file mode 100644
index 00000000..39430caf
--- /dev/null
+++ b/website/src/pages/messages/index.tsx
@@ -0,0 +1,83 @@
+import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
+import Head from "next/head";
+import { useState } from "react";
+import useSWRImmutable from "swr/immutable";
+
+import fetcher from "src/lib/fetcher";
+import { SideMenu } from "src/components/Dashboard";
+import { MessageTable } from "src/components/Messages/MessageTable";
+import { getDashboardLayout } from "src/components/Layout";
+import { colors } from "styles/Theme/colors";
+
+const MessagesDashboard = () => {
+ const bgColor = useColorModeValue(colors.light.bg, colors.dark.bg);
+ const boxBgColor = useColorModeValue("white", "gray.700");
+ const boxAccentColor = useColorModeValue("gray.200", "gray.900");
+
+ const [messages, setMessages] = useState([]);
+ const [userMessages, setUserMessages] = useState([]);
+
+ const { isLoading: isLoadingAll } = useSWRImmutable("/api/messages", fetcher, {
+ onSuccess: (data) => {
+ setMessages(data);
+ },
+ });
+
+ const { isLoading: isLoadingUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
+ onSuccess: (data) => {
+ setUserMessages(data);
+ },
+ });
+
+ return (
+ <>
+
+ Messages - Open Assistant
+
+
+
+
+
+
+
+
+
+
+
+ Most recent messages
+
+
+ {isLoadingAll ? : }
+
+
+
+
+ Your most recent messages
+
+
+ {isLoadingUser ? : }
+
+
+
+
+
+
+ >
+ );
+};
+
+MessagesDashboard.getLayout = (page) => getDashboardLayout(page);
+
+export default MessagesDashboard;
diff --git a/website/src/pages/privacy-policy.tsx b/website/src/pages/privacy-policy.tsx
index dcb3bc19..164fab93 100644
--- a/website/src/pages/privacy-policy.tsx
+++ b/website/src/pages/privacy-policy.tsx
@@ -14,7 +14,7 @@ const PrivacyPolicy = () => {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
-
+
Privacy Policy
diff --git a/website/src/pages/terms-of-service.tsx b/website/src/pages/terms-of-service.tsx
index d97c8d34..ce60a20e 100644
--- a/website/src/pages/terms-of-service.tsx
+++ b/website/src/pages/terms-of-service.tsx
@@ -14,7 +14,7 @@ const TermsOfService = () => {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
-
+
Terms Of Service