mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
Merge branch 'LAION-AI:main' into main
This commit is contained in:
@@ -15,6 +15,9 @@ jobs:
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16
|
||||
|
||||
- run: cd oasst-shared && pip install -e .
|
||||
|
||||
@@ -22,9 +25,14 @@ jobs:
|
||||
|
||||
- run: cd backend && pip install -r requirements.txt
|
||||
|
||||
- run: cd website && npm install
|
||||
|
||||
- run: ./scripts/backend-development/start-mock-server.sh
|
||||
|
||||
- name: Run contract tests
|
||||
- name: Run Python OasstApiClient contract tests
|
||||
run: ./scripts/oasst-shared-development/test.sh
|
||||
|
||||
- name: Run JavaScript OasstApiClient contract tests
|
||||
run: ./scripts/frontend-development/run-contract-test.sh
|
||||
|
||||
- run: ./scripts/backend-development/stop-mock-server.sh
|
||||
|
||||
+110
@@ -0,0 +1,110 @@
|
||||
# I’m in! Now what?
|
||||
|
||||
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
|
||||
this is for work coordination.
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
|
||||
a dedicated channel and is more public.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
|
||||
dedicated, but not as active, channel.
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
### Taking on Tasks
|
||||
|
||||
We have a growing task list
|
||||
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
|
||||
that appeals to you and make a comment that you'd like to work on it. Include in
|
||||
your comment a brief description of how you'll solve the problem and if there
|
||||
are any open questions you want to discuss. Once a project coordinator has
|
||||
assigned the issue to you, start working on it.
|
||||
|
||||
If the issue is currently unclear but you are interested, please post in Discord
|
||||
and someone can help clarify the issue with more detail.
|
||||
|
||||
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
|
||||
the system architecture, and other documentation.
|
||||
|
||||
### Submitting Work
|
||||
|
||||
We're all working on different parts of Open Assistant together. To make
|
||||
contributions smoothly we recommend the following:
|
||||
|
||||
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
|
||||
and clone it to your local machine. (Read more
|
||||
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
|
||||
1. Before working on any changes, try to
|
||||
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
|
||||
to keep it up-to-date with the upstream repository.
|
||||
1. Work on a small focused change that only touches on a few files.
|
||||
1. Run `pre-commit` and make sure all files have formatting fixed. This
|
||||
simplifies life for reviewers.
|
||||
1. Package up a small bit of work that solves part of the problem
|
||||
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
|
||||
and
|
||||
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
|
||||
1. If you're lucky, we can merge your change into `main` without any problems.
|
||||
If there's changes to files you're working on, resolve them by:
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved any conflicts, finish the review and merge into `main`.
|
||||
1. Merge in your change and move onto a new issue or the second step of your
|
||||
current issue.
|
||||
|
||||
Additionally, if someone is working on an issue that interests you, ask if they
|
||||
need help on it or would like suggestions on how to approach the issue. If so,
|
||||
share wildly. If they seem to have a good handle on it, let them work on their
|
||||
solution until a challenge comes up.
|
||||
|
||||
### When does a review finish
|
||||
|
||||
A review finishes when all blocking comments are addressed and at least one
|
||||
owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
|
||||
comments either by making the request change, explaining why it's not being
|
||||
addressed now, or filing an issue to handle it later.
|
||||
|
||||
## Developer Setup
|
||||
|
||||
Work is organized in the
|
||||
[project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs.
|
||||
Meaning we'd be happy for anyone to do these tasks.**
|
||||
|
||||
If you want to work on something, assign yourself to it or write a comment that
|
||||
you want to work on it and what you plan to do.
|
||||
|
||||
- To get started with development, if you want to work on the backend, have a
|
||||
look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at
|
||||
`scripts/frontend-development/README.md` to make a backend available.
|
||||
|
||||
There is also a minimal implementation of a frontend in the `text-frontend`
|
||||
folder.
|
||||
|
||||
We are using Python 3.10 for the backend.
|
||||
|
||||
Check out the
|
||||
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
|
||||
### Website
|
||||
|
||||
The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit
|
||||
hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you
|
||||
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
### Deployment
|
||||
|
||||
Upon making a release on GitHub, all docker images are automatically built and
|
||||
pushed to ghcr.io. The docker images are tagged with the release version, and
|
||||
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
|
||||
automatically deploy the built release to the dev machine.
|
||||
@@ -1,7 +1,24 @@
|
||||
# Open-Assistant
|
||||
<h1 align="center">
|
||||
<span>Open-Assistant</span>
|
||||
<img width="auto" height="50px" src="https://github.com/LAION-AI/Open-Assistant/blob/main/assets/logo_crop.png"/>
|
||||
</h1>
|
||||
|
||||
Open Assistant is a project meant to give everyone access to a great chat based
|
||||
large language model.
|
||||
# Table of Contents
|
||||
|
||||
- [What is Open Assistant?](#what-is-open-assistant)
|
||||
- [Do you want to try it out?](#do-you-want-to-try-it-out)
|
||||
- [The Plan](#the-plan)
|
||||
- [The Vision](#the-vision)
|
||||
- [How can you help?](#how-can-you-help)
|
||||
- [I’m in! How do I contribute?](CONTRIBUTING.md)
|
||||
|
||||
---
|
||||
|
||||
## What is Open Assistant?
|
||||
|
||||
<p align="center">
|
||||
Open Assistant is a project meant to give everyone access to a great chat based large language model.
|
||||
</p>
|
||||
|
||||
We believe that by doing this we will create a revolution in innovation in
|
||||
language. In the same way that stable-diffusion helped the world make art and
|
||||
@@ -14,7 +31,7 @@ If you are interested in taking a look at the current state of the project, you
|
||||
can set up an entire stack needed to run **Open-Assistant**, including the
|
||||
website, backend, and associated dependent services.
|
||||
|
||||
To start the demo, run this in the root directory of the repository:
|
||||
##### To start the demo, run this in the root directory of the repository:
|
||||
|
||||
```sh
|
||||
docker compose up --build
|
||||
@@ -23,22 +40,21 @@ docker compose up --build
|
||||
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
|
||||
interact with the website.
|
||||
|
||||
**Note:** When logging in via email, navigate to `http://localhost:1080` to get
|
||||
the magic email login link.
|
||||
> **Note:** When logging in via email, navigate to `http://localhost:1080` to
|
||||
> get the magic email login link.
|
||||
|
||||
**Note:** If you would like to run this in a standardized development
|
||||
environment (a
|
||||
["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
|
||||
using
|
||||
[vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
|
||||
or in a web browser using
|
||||
[GitHub Codespaces](https://github.com/features/codespaces), you can use the
|
||||
provided [`.devcontainer`](.devcontainer/) folder.
|
||||
> **Note:** If you would like to run this in a standardized development
|
||||
> environment (a
|
||||
> ["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
|
||||
> using
|
||||
> [vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
|
||||
> or in a web browser using
|
||||
> [GitHub Codespaces](https://github.com/features/codespaces), you can use the
|
||||
> provided [`.devcontainer`](.devcontainer/) folder.
|
||||
|
||||
## The Plan
|
||||
|
||||
We want to get to an initial MVP as fast as possible, by following the 3-steps
|
||||
outlined in the InstructGPT paper.
|
||||
##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
|
||||
|
||||
1. Collect high-quality human generated Instruction-Fulfillment samples
|
||||
(prompt + response), goal >50k. We design a crowdsourced process to collect
|
||||
@@ -80,113 +96,4 @@ All open source projects begin with people like you. Open source is the belief
|
||||
that if we collaborate we can together gift our knowledge and technology to the
|
||||
world for the benefit of humanity.
|
||||
|
||||
## I’m in! Now what?
|
||||
|
||||
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
|
||||
this is for work coordination.
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
|
||||
a dedicated channel and is more public.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
|
||||
dedicated, but not as active, channel.
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
### Taking on Tasks
|
||||
|
||||
We have a growing task list
|
||||
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
|
||||
that appeals to you and make a comment that you'd like to work on it. Include in
|
||||
your comment a brief description of how you'll solve the problem and if there
|
||||
are any open questions you want to discuss. Once a project coordinator has
|
||||
assigned the issue to you, start working on it.
|
||||
|
||||
If the issue is currently unclear but you are interested, please post in Discord
|
||||
and someone can help clarify the issue with more detail.
|
||||
|
||||
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
|
||||
the system architecture, and other documentation.
|
||||
|
||||
### Submitting Work
|
||||
|
||||
We're all working on different parts of Open Assistant together. To make
|
||||
contributions smoothly we recommend the following:
|
||||
|
||||
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
|
||||
and clone it to your local machine. (Read more
|
||||
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
|
||||
1. Before working on any changes, try to
|
||||
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
|
||||
to keep it up-to-date with the upstream repository.
|
||||
1. Work on a small focused change that only touches on a few files.
|
||||
1. Run `pre-commit` and make sure all files have formatting fixed. This
|
||||
simplifies life for reviewers.
|
||||
1. Package up a small bit of work that solves part of the problem
|
||||
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
|
||||
and
|
||||
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
|
||||
1. If you're lucky, we can merge your change into `main` without any problems.
|
||||
If there's changes to files you're working on, resolve them by:
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved any conflicts, finish the review and merge into `main`.
|
||||
1. Merge in your change and move onto a new issue or the second step of your
|
||||
current issue.
|
||||
|
||||
Additionally, if someone is working on an issue that interests you, ask if they
|
||||
need help on it or would like suggestions on how to approach the issue. If so,
|
||||
share wildly. If they seem to have a good handle on it, let them work on their
|
||||
solution until a challenge comes up.
|
||||
|
||||
### When does a review finish
|
||||
|
||||
A review finishes when all blocking comments are addressed and at least one
|
||||
owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
|
||||
comments either by making the request change, explaining why it's not being
|
||||
addressed now, or filing an issue to handle it later.
|
||||
|
||||
## Developer Setup
|
||||
|
||||
Work is organized in the
|
||||
[project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs.
|
||||
Meaning we'd be happy for anyone to do these tasks.**
|
||||
|
||||
If you want to work on something, assign yourself to it or write a comment that
|
||||
you want to work on it and what you plan to do.
|
||||
|
||||
- To get started with development, if you want to work on the backend, have a
|
||||
look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at
|
||||
`scripts/frontend-development/README.md` to make a backend available.
|
||||
|
||||
There is also a minimal implementation of a frontend in the `text-frontend`
|
||||
folder.
|
||||
|
||||
We are using Python 3.10 for the backend.
|
||||
|
||||
Check out the
|
||||
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
|
||||
### Website
|
||||
|
||||
The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit
|
||||
hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you
|
||||
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
### Deployment
|
||||
|
||||
Upon making a release on GitHub, all docker images are automatically built and
|
||||
pushed to ghcr.io. The docker images are tagged with the release version, and
|
||||
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
|
||||
automatically deploy the built release to the dev machine.
|
||||
Check out our [contributing guide](CONTRIBUTING.md) to get started.
|
||||
|
||||
@@ -82,6 +82,7 @@
|
||||
DEBUG_ALLOW_ANY_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
ports:
|
||||
- 8080:8080
|
||||
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
+4
-2
@@ -11,10 +11,12 @@ Example contents of a `.env` file for the backend:
|
||||
```
|
||||
DATABASE_URI="postgresql://<username>:<password>@<host>/<database_name>"
|
||||
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"]
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
```
|
||||
|
||||
## Running the REST Server locally for development
|
||||
|
||||
Have a look into the main `README.md` file for more information on how to set up
|
||||
the backend for development.
|
||||
the backend for development. Use the scripts within the
|
||||
scripts/backend-development folder to run the BE API locally.
|
||||
|
||||
@@ -36,3 +36,4 @@ environments:
|
||||
secrets:
|
||||
# Note: URI, not URL.
|
||||
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
|
||||
REDIS_HOST: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/REDIS_HOST
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
model_name: kalpeshk2011/rankgen-t5-base-all
|
||||
tokenizer_name: google/t5-v1_1-base
|
||||
learning_rate: 6e-6
|
||||
gradient_checkpointing: false
|
||||
fp16: true
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
freeze_layer: 20
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 400
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,19 @@
|
||||
model_name: kalpeshk2011/rankgen-t5-base-all
|
||||
# model_name: kalpeshk2011/rankgen-t5-xl-all
|
||||
# model_name: kalpeshk2011/rankgen-t5-xl-pg19
|
||||
# model_name: kalpeshk2011/rankgen-t5-large-all
|
||||
tokenizer_name: google/t5-v1_1-base
|
||||
learning_rate: 6e-6
|
||||
gradient_checkpointing: false
|
||||
fp16: false
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
freeze_layer: 20
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 400
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
from transformers import AutoModel
|
||||
|
||||
|
||||
class RankGenModel(torch.nn.Module):
|
||||
def __init__(self, model_name):
|
||||
super().__init__()
|
||||
self.rankgen_hf_hub = model_name
|
||||
assert model_name in [
|
||||
"kalpeshk2011/rankgen-t5-xl-all",
|
||||
"kalpeshk2011/rankgen-t5-xl-pg19",
|
||||
"kalpeshk2011/rankgen-t5-base-all",
|
||||
"kalpeshk2011/rankgen-t5-large-all",
|
||||
]
|
||||
self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)
|
||||
|
||||
def forward(self, prefixes, suffixes):
|
||||
# print(list(self.model.parameters()))
|
||||
# raise Exception("stop")
|
||||
embedded_prefixes = self.model(**prefixes)
|
||||
embedded_suffixes = self.model(**suffixes)
|
||||
# take dot product of each row independently
|
||||
dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
|
||||
|
||||
# print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}")
|
||||
# raise Exception("stop")
|
||||
return dot_products
|
||||
@@ -22,11 +22,41 @@ from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankGenCollator:
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
max_examples: Optional[int] = None
|
||||
|
||||
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
|
||||
prefixes = []
|
||||
better_answers = []
|
||||
worse_answers = []
|
||||
for question, pairs in batch:
|
||||
for (pos, neg) in pairs:
|
||||
prefixes.append("pre " + question)
|
||||
better_answers.append("suffi " + pos)
|
||||
worse_answers.append("suffi " + neg)
|
||||
|
||||
tokenized_prefixes = self.tokenizer(
|
||||
prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
|
||||
)
|
||||
tokenized_pos = self.tokenizer(
|
||||
better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
|
||||
)
|
||||
tokenized_neg = self.tokenizer(
|
||||
worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
|
||||
)
|
||||
return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPairRank:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
datasets==2.8.0
|
||||
evaluate==0.4.0
|
||||
scikit-learn==1.2.0
|
||||
torch==1.12.1+cu116
|
||||
sentencepiece==0.1.97
|
||||
torch>=1.12.1
|
||||
transformers==4.25.1
|
||||
wandb==0.13.7
|
||||
|
||||
@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from models import RankGenModel
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
@@ -46,14 +47,16 @@ class RankLoss(nn.Module):
|
||||
self.log_sigmoid = nn.LogSigmoid()
|
||||
|
||||
def forward(self, pos, neg):
|
||||
return -self.log_sigmoid(pos - neg + self.eps).mean()
|
||||
loss = -self.log_sigmoid(pos - neg + self.eps).mean()
|
||||
return loss
|
||||
|
||||
|
||||
class RankTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
model_name: str = None,
|
||||
args: Optional[TrainingArguments] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
@@ -79,15 +82,25 @@ class RankTrainer(Trainer):
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
self.model_name = model_name
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
if "rankgen" in self.model_name:
|
||||
positive_outputs = model(inputs["prefix"], inputs["positive"])
|
||||
negative_outputs = model(inputs["prefix"], inputs["negative"])
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(positive_outputs, negative_outputs)
|
||||
else:
|
||||
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
|
||||
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
@@ -109,24 +122,37 @@ class RankTrainer(Trainer):
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
with torch.inference_mode():
|
||||
if "rankgen" in self.model_name:
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
positive_outputs = model(inputs["prefix"], inputs["positive"])
|
||||
negative_outputs = model(inputs["prefix"], inputs["negative"])
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(positive_outputs, negative_outputs)
|
||||
else:
|
||||
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
|
||||
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
|
||||
return (loss, outputs, None)
|
||||
else:
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
loss = loss.mean().detach()
|
||||
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf["model_name"]
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "rankgen-t5" in model_name:
|
||||
model = RankGenModel(model_name)
|
||||
else:
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
@@ -134,7 +160,6 @@ if __name__ == "__main__":
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
@@ -142,7 +167,7 @@ if __name__ == "__main__":
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
fp16=training_conf["fp16"],
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
@@ -154,7 +179,7 @@ if __name__ == "__main__":
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
report_to="local",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
@@ -169,17 +194,23 @@ if __name__ == "__main__":
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
collate_fn = DataCollatorForPairRank(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
|
||||
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
|
||||
|
||||
if "rankgen" in model_name:
|
||||
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
|
||||
else:
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
|
||||
assert len(evals) > 0
|
||||
trainer = RankTrainer(
|
||||
model,
|
||||
args,
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
args=args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
# trainer.evaluate()
|
||||
trainer.train()
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
import yaml
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, T5Tokenizer
|
||||
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
|
||||
@@ -25,7 +25,10 @@ def webgpt_return_format(row):
|
||||
|
||||
|
||||
def get_tokenizer(tokenizer_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if "t5" in tokenizer_name: # rankgen
|
||||
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if "galactica" in tokenizer_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
@@ -67,6 +70,10 @@ def freeze_top_n_layers(model, target_layers):
|
||||
|
||||
|
||||
def argument_parsing(parser):
|
||||
args = parser.parse_args()
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
training_conf = yaml.safe_load(f.read())
|
||||
|
||||
default_params = {
|
||||
"num_train_epochs": 4,
|
||||
"learning_rate": 3e-5,
|
||||
@@ -78,10 +85,9 @@ def argument_parsing(parser):
|
||||
"gradient_accumulation_steps": 8,
|
||||
"gradient_checkpointing": False,
|
||||
"datasets": ["webgpt"],
|
||||
"fp16": True,
|
||||
"tokenizer_name": training_conf["model_name"],
|
||||
}
|
||||
args = parser.parse_args()
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
training_conf = yaml.safe_load(f.read())
|
||||
|
||||
params = {**default_params, **training_conf}
|
||||
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
|
||||
|
||||
@@ -33,6 +33,6 @@ Experimental results in wandb
|
||||
## TODOS
|
||||
|
||||
- decide on a model
|
||||
- add special token to declare prompt and reply. Do nto freeze the weights for
|
||||
these
|
||||
- Merge utils etc with reward model
|
||||
- Casual Modelling for GPT-JT does not leverage the bidirectional mask for the
|
||||
prompt? (https://huggingface.co/togethercomputer/GPT-JT-6B-v1)
|
||||
|
||||
@@ -32,6 +32,17 @@ galactica-125:
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
gpt-jt:
|
||||
learning_rate: 2e-6
|
||||
model_name: togethercomputer/GPT-JT-6B-v1
|
||||
weight_decay: 0.01
|
||||
max_length: 1024
|
||||
warmup_steps: 600
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 2
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
debug:
|
||||
eval_steps: 20
|
||||
eval_size: 100
|
||||
|
||||
@@ -2,6 +2,8 @@ from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
|
||||
|
||||
|
||||
class SquadV2Dataset(Dataset):
|
||||
def __init__(self, cache_dir, split):
|
||||
|
||||
@@ -6,6 +6,8 @@ import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
from . import QA_SPECIAL_TOKENS
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueDataCollator:
|
||||
@@ -19,22 +21,21 @@ class DialogueDataCollator:
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
# TODO add special tokens for question and answer here
|
||||
# additional_special_tokens = ['<question>', '<answer>']
|
||||
prompt_tokens = ["Question: ", "Answer: "]
|
||||
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
|
||||
for messages in features:
|
||||
assert len(messages) % 2 == 0, "Number of messages must be even"
|
||||
messages = [
|
||||
(prompt_tokens[0] if i % 2 == 0 else "") + x + ((" " + prompt_tokens[1]) if i % 2 == 0 else "")
|
||||
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
|
||||
+ x
|
||||
+ (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
|
||||
for i, x in enumerate(messages)
|
||||
]
|
||||
|
||||
# Add a way for the model to terminate generation, reinitialize prompter
|
||||
messages.append(prompt_tokens[0])
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
messages.append(QA_SPECIAL_TOKENS["Question"])
|
||||
|
||||
flatten_messages.append(
|
||||
self.tokenizer(
|
||||
@@ -47,8 +48,10 @@ class DialogueDataCollator:
|
||||
|
||||
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
|
||||
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
|
||||
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John.
|
||||
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3
|
||||
# Label mask is true when predicting a token that is part of the answer, false otherwise.
|
||||
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
|
||||
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
|
||||
# LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0
|
||||
|
||||
# If no result in next, we are predicting the last termination token(s)
|
||||
message_indices = list(
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
datasets==2.8.0
|
||||
numpy==1.23.0
|
||||
PyYAML==6.0
|
||||
scikit_learn==1.2.0
|
||||
torch==1.13.1
|
||||
transformers==4.25.1
|
||||
@@ -67,6 +67,8 @@ class SFTTrainer(Trainer):
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(args.loss_function)
|
||||
|
||||
def fetch_scheduler(self):
|
||||
@@ -112,7 +114,7 @@ class SFTTrainer(Trainer):
|
||||
|
||||
with torch.no_grad():
|
||||
loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
|
||||
labels[~labels_mask] = -1
|
||||
labels[~labels_mask] = -100 # padding_index
|
||||
|
||||
loss = loss.mean().detach()
|
||||
|
||||
@@ -159,8 +161,8 @@ def argument_parsing(notebook=False, notebook_args=None):
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing()
|
||||
|
||||
model = get_model(training_conf)
|
||||
tokenizer = get_tokenizer(training_conf)
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from custom_datasets import get_one_dataset
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SUPPORTED_MODELS = ["galactica"]
|
||||
SUPPORTED_MODELS = ["galactica", "GPT-JT"] # deprecated ..
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
@@ -17,10 +17,19 @@ def get_tokenizer(conf):
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
if "additional_special_tokens" not in tokenizer.special_tokens_map
|
||||
else tokenizer.special_tokens_map["additional_special_tokens"]
|
||||
)
|
||||
additional_special_tokens = list(set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
|
||||
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(conf):
|
||||
def get_model(conf, tokenizer):
|
||||
if not any([x in conf.model_name for x in SUPPORTED_MODELS]):
|
||||
raise ValueError(
|
||||
f"Model {conf.model_name} not supported. Supported models: {SUPPORTED_MODELS}. "
|
||||
@@ -29,6 +38,11 @@ def get_model(conf):
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
if len(tokenizer) != model.get_input_embeddings().num_embeddings:
|
||||
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if conf.freeze_layer:
|
||||
model = freeze_top_n_layers(model, conf.freeze_layer)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from oasst_shared.exceptions import OasstErrorCode
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, conint, conlist, constr
|
||||
|
||||
|
||||
class TaskRequestType(str, enum.Enum):
|
||||
@@ -203,7 +203,7 @@ class TextReplyToMessage(Interaction):
|
||||
type: Literal["text_reply_to_message"] = "text_reply_to_message"
|
||||
message_id: str
|
||||
user_message_id: str
|
||||
text: str
|
||||
text: constr(min_length=1, strip_whitespace=True)
|
||||
|
||||
|
||||
class MessageRating(Interaction):
|
||||
@@ -211,7 +211,7 @@ class MessageRating(Interaction):
|
||||
|
||||
type: Literal["message_rating"] = "message_rating"
|
||||
message_id: str
|
||||
rating: int
|
||||
rating: conint(gt=0)
|
||||
|
||||
|
||||
class MessageRanking(Interaction):
|
||||
@@ -219,7 +219,7 @@ class MessageRanking(Interaction):
|
||||
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
message_id: str
|
||||
ranking: list[int]
|
||||
ranking: conlist(item_type=int, min_items=1)
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
|
||||
+11
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
|
||||
# switch to website directory
|
||||
pushd "$parent_path/../../website"
|
||||
|
||||
set -xe
|
||||
|
||||
npm run cypress:run:contract
|
||||
|
||||
popd
|
||||
@@ -0,0 +1,9 @@
|
||||
import { defineConfig } from "cypress";
|
||||
|
||||
export default defineConfig({
|
||||
e2e: {
|
||||
// No baseUrl here, because we don't need it for contract testing
|
||||
baseUrl: null,
|
||||
specPattern: "cypress/contract/*.cy.{ts,js}",
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,29 @@
|
||||
import OasstApiClient from "src/lib/oasst_api_client";
|
||||
|
||||
describe("Contract test for Oasst API", function () {
|
||||
// Assumes this is running the mock server.
|
||||
const oasstApiClient = new OasstApiClient("http://localhost:8080", "test");
|
||||
|
||||
it("can fetch a task", async () => {
|
||||
expect(
|
||||
await oasstApiClient.fetchTask("random", {
|
||||
sub: "test",
|
||||
name: "test",
|
||||
email: "test",
|
||||
})
|
||||
).to.be.not.null;
|
||||
});
|
||||
|
||||
it("can ack a task", async () => {
|
||||
const task = await oasstApiClient.fetchTask("random", {
|
||||
sub: "test",
|
||||
name: "test",
|
||||
email: "test",
|
||||
});
|
||||
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
|
||||
});
|
||||
|
||||
// TODO(#354): Add test for 204
|
||||
// TODO(#354): Add test for parsing >=300, throwing an OasstError
|
||||
// TODO(#354): Add test for parsing >=300, throwing a generic error
|
||||
});
|
||||
@@ -2,6 +2,14 @@
|
||||
const nextConfig = {
|
||||
output: "standalone",
|
||||
reactStrictMode: true,
|
||||
images: {
|
||||
remotePatterns: [
|
||||
{
|
||||
protocol: "https",
|
||||
hostname: "**.discordapp.com",
|
||||
},
|
||||
],
|
||||
},
|
||||
experimental: {
|
||||
/* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint
|
||||
If this can be resolved, we should re-enable this.
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
"build-storybook": "build-storybook",
|
||||
"cypress": "cypress open",
|
||||
"cypress:run": "cypress run",
|
||||
"cypress:run:contract": "cypress run --config-file ./cypress.config.contract.js",
|
||||
"cypress:image-baseline": "cypress-image-diff -u",
|
||||
"fix:lint": "eslint --fix src/ --ext .js,.jsx,.ts,.tsx",
|
||||
"fix:format": "prettier --write ./src",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react";
|
||||
import { useRouter } from "next/router";
|
||||
import { FiLayout, FiSun } from "react-icons/fi";
|
||||
import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export function SideMenu() {
|
||||
@@ -13,6 +13,12 @@ export function SideMenu() {
|
||||
desc: "Dashboard Home",
|
||||
icon: FiLayout,
|
||||
},
|
||||
{
|
||||
label: "Messages",
|
||||
pathname: "/messages",
|
||||
desc: "Messages Dashboard",
|
||||
icon: FiMessageSquare,
|
||||
},
|
||||
// {
|
||||
// label: "Leaderboard",
|
||||
// pathname: "#",
|
||||
|
||||
@@ -2,11 +2,17 @@ import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } fro
|
||||
import Link from "next/link";
|
||||
|
||||
const crTasks = [
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
type: "create",
|
||||
pathname: "/create/initial_prompt",
|
||||
},
|
||||
{
|
||||
label: "Reply as User",
|
||||
desc: "Chat with Open Assistant and help improve it’s responses as you interact with it.",
|
||||
type: "create",
|
||||
pathname: "/create/assistant_reply",
|
||||
pathname: "/create/user_reply",
|
||||
},
|
||||
{
|
||||
label: "Reply as Assistant",
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
SliderThumb,
|
||||
SliderTrack,
|
||||
Spacer,
|
||||
Tooltip,
|
||||
useBoolean,
|
||||
useId,
|
||||
} from "@chakra-ui/react";
|
||||
@@ -69,11 +70,15 @@ export const FlaggableElement = (props) => {
|
||||
>
|
||||
<Grid templateColumns="1fr min-content" gap={2}>
|
||||
<PopoverAnchor>{props.children}</PopoverAnchor>
|
||||
<PopoverTrigger>
|
||||
<Button h="full">
|
||||
<FlagIcon className="w-4 text-gray-400 group-hover:text-gray-500" aria-hidden="true" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<Tooltip hasArrow label="Report" bg="red.600">
|
||||
<div>
|
||||
<PopoverTrigger>
|
||||
<Button h="full">
|
||||
<FlagIcon className="w-4 text-gray-400 group-hover:text-gray-500" aria-hidden="true" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
</div>
|
||||
</Tooltip>
|
||||
</Grid>
|
||||
|
||||
<PopoverContent width="fit-content">
|
||||
|
||||
@@ -42,7 +42,7 @@ export function UserMenu() {
|
||||
className="flex items-center gap-4 p-1 lg:pr-6 rounded-full transition-colors duration-300"
|
||||
>
|
||||
<Image
|
||||
src="/images/temp-avatars/av5.jpg"
|
||||
src={session.user.image || "/images/temp-avatars/av1.jpg"}
|
||||
alt="Profile Picture"
|
||||
width="36"
|
||||
height="36"
|
||||
|
||||
@@ -25,4 +25,11 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
|
||||
</div>
|
||||
);
|
||||
|
||||
export const getDashboardLayout = (page: React.ReactElement) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
{page}
|
||||
</div>
|
||||
);
|
||||
|
||||
export const noLayout = (page: React.ReactElement) => page;
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react";
|
||||
import { MessageTableEntry } from "./MessageTableEntry";
|
||||
|
||||
export function MessageTable({ messages }) {
|
||||
return (
|
||||
<Stack divider={<StackDivider />} spacing="4">
|
||||
{messages.map((item, idx) => (
|
||||
<MessageTableEntry item={item} idx={idx} key={item.id} />
|
||||
))}
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import NextLink from "next/link";
|
||||
import { FlaggableElement } from "../FlaggableElement";
|
||||
|
||||
export function MessageTableEntry({ item, idx }) {
|
||||
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
|
||||
|
||||
return (
|
||||
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
|
||||
<HStack>
|
||||
<Avatar
|
||||
name={`${boolean(item.is_assistant) ? "Assitant" : "User"}`}
|
||||
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
|
||||
/>
|
||||
<LinkBox className={`p-4 rounded-md text-white whitespace-pre-wrap ${bgColor} text-white w-full`}>
|
||||
<NextLink href={`/messages/${item.id}`} passHref>
|
||||
{item.text}
|
||||
</NextLink>
|
||||
</LinkBox>
|
||||
</HStack>
|
||||
</FlaggableElement>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
import { JWT } from "next-auth/jwt";
|
||||
|
||||
class OasstError {
|
||||
message: string;
|
||||
errorCode: number;
|
||||
httpStatusCode: number;
|
||||
|
||||
constructor(message: string, errorCode: number, httpStatusCode: number) {
|
||||
this.message = message;
|
||||
this.errorCode = errorCode;
|
||||
this.httpStatusCode = httpStatusCode;
|
||||
}
|
||||
}
|
||||
|
||||
export default class OasstApiClient {
|
||||
constructor(private readonly oasstApiUrl: string, private readonly oasstApiKey: string) {}
|
||||
|
||||
private async post(path: string, body: any): Promise<any> {
|
||||
const resp = await fetch(`${this.oasstApiUrl}${path}`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"X-API-Key": this.oasstApiKey,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (resp.status == 204) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (resp.status >= 300) {
|
||||
const errorText = await resp.text();
|
||||
try {
|
||||
const error = JSON.parse(errorText);
|
||||
throw new OasstError(error.message, error.error_code, resp.status);
|
||||
} catch (e) {
|
||||
throw new OasstError(errorText, 0, resp.status);
|
||||
}
|
||||
}
|
||||
|
||||
return await resp.json();
|
||||
}
|
||||
|
||||
// TODO return a strongly typed Task?
|
||||
// This method is used to store a task in RegisteredTask.task.
|
||||
// This is a raw Json type, so we can't use it to strongly type the task.
|
||||
async fetchTask(taskType: string, userToken: JWT): Promise<any> {
|
||||
return this.post("/api/v1/tasks/", {
|
||||
type: taskType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name || userToken.email,
|
||||
auth_method: "local",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async ackTask(taskId: string, messageId: string): Promise<void> {
|
||||
return this.post(`/api/v1/tasks/${taskId}/ack`, {
|
||||
message_id: messageId,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -36,22 +36,24 @@ export default function Account() {
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<form onSubmit={updateUser}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
placeholder="Edit Username"
|
||||
type="text"
|
||||
value={username}
|
||||
></Input>
|
||||
<Button disabled={!username} type="submit" value="Change">
|
||||
Submit
|
||||
</Button>
|
||||
</InputGroup>
|
||||
</form>
|
||||
</main>
|
||||
<div className="oa-basic-theme">
|
||||
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<form onSubmit={updateUser}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
placeholder="Edit Username"
|
||||
type="text"
|
||||
value={username}
|
||||
></Input>
|
||||
<Button disabled={!username} type="submit" value="Change">
|
||||
Submit
|
||||
</Button>
|
||||
</InputGroup>
|
||||
</form>
|
||||
</main>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,13 +19,15 @@ export default function Account() {
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<Button>
|
||||
<Link href="/account/edit">Edit Username</Link>
|
||||
</Button>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
<div className="oa-basic-theme">
|
||||
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<Button>
|
||||
<Link href="/account/edit">Edit Username</Link>
|
||||
</Button>
|
||||
<p>{session.user.email}</p>
|
||||
</main>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -0,0 +1,29 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
//TODO: add params if needed
|
||||
const params = new URLSearchParams({
|
||||
username: token.sub,
|
||||
});
|
||||
|
||||
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
},
|
||||
});
|
||||
const messages = await messagesRes.json();
|
||||
|
||||
// Send recieved messages to the client.
|
||||
res.status(200).json(messages);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -1,4 +1,5 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import OasstApiClient from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
@@ -20,25 +21,10 @@ const handler = async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
|
||||
|
||||
// Fetch the new task.
|
||||
//
|
||||
// This needs to be refactored into an easier to use library.
|
||||
const taskRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
type: task_type,
|
||||
user: {
|
||||
id: token.sub,
|
||||
display_name: token.name || token.email,
|
||||
auth_method: "local",
|
||||
},
|
||||
}),
|
||||
});
|
||||
const task = await taskRes.json();
|
||||
const task = await oasstApiClient.fetchTask(task_type, token);
|
||||
|
||||
// Store the task and link it to the user..
|
||||
const registeredTask = await prisma.registeredTask.create({
|
||||
@@ -53,16 +39,7 @@ const handler = async (req, res) => {
|
||||
});
|
||||
|
||||
// Update the backend with our Task ID
|
||||
await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message_id: registeredTask.id,
|
||||
}),
|
||||
});
|
||||
await oasstApiClient.ackTask(task.id, registeredTask.id);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(registeredTask);
|
||||
|
||||
@@ -36,6 +36,7 @@ const handler = async (req, res) => {
|
||||
|
||||
// Send the interaction to the Task Backend. This automatically fetches the
|
||||
// next task in the sequence (or the done task).
|
||||
// TODO(#353): Move this into OasstApiClient.
|
||||
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/interaction`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Container, Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
@@ -21,6 +21,12 @@ const AssistantReply = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Container, Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
@@ -27,6 +27,12 @@ const InitialPrompt = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputRef.current.value.trim();
|
||||
trigger({
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
@@ -21,6 +21,12 @@ const UserReply = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Box, useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
@@ -27,11 +28,6 @@ const Dashboard = () => {
|
||||
);
|
||||
};
|
||||
|
||||
Dashboard.getLayout = (page) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
{page}
|
||||
</div>
|
||||
);
|
||||
Dashboard.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export default Dashboard;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { ContextMessages } from "src/components/ContextMessages";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
@@ -26,6 +26,12 @@ const RankAssistantReplies = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
@@ -32,6 +32,12 @@ const RankInitialPrompts = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
id: task.id,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { ContextMessages } from "src/components/ContextMessages";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
@@ -26,6 +26,12 @@ const RankUserReplies = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import { SideMenu } from "src/components/Dashboard";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
const MessagesDashboard = () => {
|
||||
const bgColor = useColorModeValue(colors.light.bg, colors.dark.bg);
|
||||
const boxBgColor = useColorModeValue("white", "gray.700");
|
||||
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
|
||||
|
||||
const [messages, setMessages] = useState([]);
|
||||
const [userMessages, setUserMessages] = useState([]);
|
||||
|
||||
const { isLoading: isLoadingAll } = useSWRImmutable("/api/messages", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setMessages(data);
|
||||
},
|
||||
});
|
||||
|
||||
const { isLoading: isLoadingUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setUserMessages(data);
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Messages - Open Assistant</title>
|
||||
<meta name="description" content="Chat with Open Assistant and provide feedback." />
|
||||
</Head>
|
||||
<Box backgroundColor={bgColor} className="sm:overflow-hidden">
|
||||
<Box className="sm:flex h-full gap-6">
|
||||
<Box className="p-6 sm:pr-0">
|
||||
<SideMenu />
|
||||
</Box>
|
||||
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
|
||||
<SimpleGrid columns={[1, 1, 1, 2]} gap={4}>
|
||||
<Box>
|
||||
<Text className="text-2xl font-bold" pb="4">
|
||||
Most recent messages
|
||||
</Text>
|
||||
<Box
|
||||
backgroundColor={boxBgColor}
|
||||
boxShadow="base"
|
||||
dropShadow={boxAccentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
<Text className="text-2xl font-bold" pb="4">
|
||||
Your most recent messages
|
||||
</Text>
|
||||
<Box
|
||||
backgroundColor={boxBgColor}
|
||||
boxShadow="base"
|
||||
dropShadow={boxAccentColor}
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
MessagesDashboard.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export default MessagesDashboard;
|
||||
@@ -14,7 +14,7 @@ const PrivacyPolicy = () => {
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main>
|
||||
<main className="oa-basic-theme">
|
||||
<Container>
|
||||
<Heading as="h1" size="3xl">
|
||||
Privacy Policy
|
||||
|
||||
@@ -14,7 +14,7 @@ const TermsOfService = () => {
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<main>
|
||||
<main className="oa-basic-theme">
|
||||
<Container>
|
||||
<Heading as="h1" size="3xl">
|
||||
Terms Of Service
|
||||
|
||||
Reference in New Issue
Block a user