This commit is contained in:
Sotirios Anagnostidis
2023-01-06 21:24:36 +01:00
108 changed files with 4811 additions and 460 deletions
+9 -1
View File
@@ -15,6 +15,9 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: actions/setup-node@v3
with:
node-version: 16
- run: cd oasst-shared && pip install -e .
@@ -22,9 +25,14 @@ jobs:
- run: cd backend && pip install -r requirements.txt
- run: cd website && npm ci
- run: ./scripts/backend-development/start-mock-server.sh
- name: Run contract tests
- name: Run Python OasstApiClient contract tests
run: ./scripts/oasst-shared-development/test.sh
- name: Run JavaScript OasstApiClient contract tests
run: ./scripts/frontend-development/run-contract-test.sh
- run: ./scripts/backend-development/stop-mock-server.sh
+1 -1
View File
@@ -26,7 +26,7 @@
#
# /WARNING!
exclude: build|stubs|^bot/templates/$
exclude: build|stubs|^bot/templates/$|openassistant/templates
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
+110
View File
@@ -0,0 +1,110 @@
# Im in! Now what?
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
this is for work coordination.
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
a dedicated channel and is more public.
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
dedicated, but not as active, channel.
[Visit the Notion](https://ykilcher.com/open-assistant)
### Taking on Tasks
We have a growing task list
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
that appeals to you and make a comment that you'd like to work on it. Include in
your comment a brief description of how you'll solve the problem and if there
are any open questions you want to discuss. Once a project coordinator has
assigned the issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in Discord
and someone can help clarify the issue with more detail.
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
the system architecture, and other documentation.
### Submitting Work
We're all working on different parts of Open Assistant together. To make
contributions smoothly we recommend the following:
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
and clone it to your local machine. (Read more
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
1. Before working on any changes, try to
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
to keep it up-to-date with the upstream repository.
1. Work on a small focused change that only touches on a few files.
1. Run `pre-commit` and make sure all files have formatting fixed. This
simplifies life for reviewers.
1. Package up a small bit of work that solves part of the problem
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
and
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
1. If you're lucky, we can merge your change into `main` without any problems.
If there's changes to files you're working on, resolve them by:
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved any conflicts, finish the review and merge into `main`.
1. Merge in your change and move onto a new issue or the second step of your
current issue.
Additionally, if someone is working on an issue that interests you, ask if they
need help on it or would like suggestions on how to approach the issue. If so,
share wildly. If they seem to have a good handle on it, let them work on their
solution until a challenge comes up.
### When does a review finish
A review finishes when all blocking comments are addressed and at least one
owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
comments either by making the request change, explaining why it's not being
addressed now, or filing an issue to handle it later.
## Developer Setup
Work is organized in the
[project board](https://github.com/orgs/LAION-AI/projects/3).
**Anything that is in the `Todo` column and not assigned, is up for grabs.
Meaning we'd be happy for anyone to do these tasks.**
If you want to work on something, assign yourself to it or write a comment that
you want to work on it and what you plan to do.
- To get started with development, if you want to work on the backend, have a
look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at
`scripts/frontend-development/README.md` to make a backend available.
There is also a minimal implementation of a frontend in the `text-frontend`
folder.
We are using Python 3.10 for the backend.
Check out the
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
### Website
The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
### Deployment
Upon making a release on GitHub, all docker images are automatically built and
pushed to ghcr.io. The docker images are tagged with the release version, and
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
automatically deploy the built release to the dev machine.
+43 -126
View File
@@ -1,7 +1,34 @@
# Open-Assistant
<h1 align="center">
<span>Open-Assistant</span>
<img width="auto" height="50px" src="https://github.com/LAION-AI/Open-Assistant/blob/main/assets/logo_crop.png"/>
</h1>
Open Assistant is a project meant to give everyone access to a great chat based
large language model.
<div align="center">
<a href="">![GitHub Repo stars](https://img.shields.io/github/stars/LAION-AI/Open-Assistant?style=social)</a>
<a href="">![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/build-frontend.yaml?label=frontend)</a>
<a href="">![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/pre-commit.yaml?label=pre-commit)</a>
<a href="">![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/LAION-AI/Open-Assistant/test-api-contract.yaml?label=api)</a>
<a href="">![GitHub release (latest by date)](https://img.shields.io/github/v/release/LAION-AI/Open-Assistant)</a>
</div>
# Table of Contents
- [What is Open Assistant?](#what-is-open-assistant)
- [Do you want to try it out?](#do-you-want-to-try-it-out)
- [The Plan](#the-plan)
- [The Vision](#the-vision)
- [How can you help?](#how-can-you-help)
- [Im in! How do I contribute?](CONTRIBUTING.md)
---
## What is Open Assistant?
<p align="center">
Open Assistant is a project meant to give everyone access to a great chat based large language model.
</p>
We believe that by doing this we will create a revolution in innovation in
language. In the same way that stable-diffusion helped the world make art and
@@ -14,7 +41,7 @@ If you are interested in taking a look at the current state of the project, you
can set up an entire stack needed to run **Open-Assistant**, including the
website, backend, and associated dependent services.
To start the demo, run this in the root directory of the repository:
##### To start the demo, run this in the root directory of the repository:
```sh
docker compose up --build
@@ -23,22 +50,21 @@ docker compose up --build
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
interact with the website.
**Note:** When logging in via email, navigate to `http://localhost:1080` to get
the magic email login link.
> **Note:** When logging in via email, navigate to `http://localhost:1080` to
> get the magic email login link.
**Note:** If you would like to run this in a standardized development
environment (a
["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
using
[vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
or in a web browser using
[GitHub Codespaces](https://github.com/features/codespaces), you can use the
provided [`.devcontainer`](.devcontainer/) folder.
> **Note:** If you would like to run this in a standardized development
> environment (a
> ["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
> using
> [vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
> or in a web browser using
> [GitHub Codespaces](https://github.com/features/codespaces), you can use the
> provided [`.devcontainer`](.devcontainer/) folder.
## The Plan
We want to get to an initial MVP as fast as possible, by following the 3-steps
outlined in the InstructGPT paper.
##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
1. Collect high-quality human generated Instruction-Fulfillment samples
(prompt + response), goal >50k. We design a crowdsourced process to collect
@@ -80,113 +106,4 @@ All open source projects begin with people like you. Open source is the belief
that if we collaborate we can together gift our knowledge and technology to the
world for the benefit of humanity.
## Im in! Now what?
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
this is for work coordination.
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
a dedicated channel and is more public.
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
dedicated, but not as active, channel.
[Visit the Notion](https://ykilcher.com/open-assistant)
### Taking on Tasks
We have a growing task list
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
that appeals to you and make a comment that you'd like to work on it. Include in
your comment a brief description of how you'll solve the problem and if there
are any open questions you want to discuss. Once a project coordinator has
assigned the issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in Discord
and someone can help clarify the issue with more detail.
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
the system architecture, and other documentation.
### Submitting Work
We're all working on different parts of Open Assistant together. To make
contributions smoothly we recommend the following:
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
and clone it to your local machine. (Read more
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
1. Before working on any changes, try to
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
to keep it up-to-date with the upstream repository.
1. Work on a small focused change that only touches on a few files.
1. Run `pre-commit` and make sure all files have formatting fixed. This
simplifies life for reviewers.
1. Package up a small bit of work that solves part of the problem
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
and
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
1. If you're lucky, we can merge your change into `main` without any problems.
If there's changes to files you're working on, resolve them by:
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved any conflicts, finish the review and merge into `main`.
1. Merge in your change and move onto a new issue or the second step of your
current issue.
Additionally, if someone is working on an issue that interests you, ask if they
need help on it or would like suggestions on how to approach the issue. If so,
share wildly. If they seem to have a good handle on it, let them work on their
solution until a challenge comes up.
### When does a review finish
A review finishes when all blocking comments are addressed and at least one
owning reviewer has approved the PR. Be sure to acknowledge any non-blocking
comments either by making the request change, explaining why it's not being
addressed now, or filing an issue to handle it later.
## Developer Setup
Work is organized in the
[project board](https://github.com/orgs/LAION-AI/projects/3).
**Anything that is in the `Todo` column and not assigned, is up for grabs.
Meaning we'd be happy for anyone to do these tasks.**
If you want to work on something, assign yourself to it or write a comment that
you want to work on it and what you plan to do.
- To get started with development, if you want to work on the backend, have a
look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at
`scripts/frontend-development/README.md` to make a backend available.
There is also a minimal implementation of a frontend in the `text-frontend`
folder.
We are using Python 3.10 for the backend.
Check out the
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
### Website
The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
### Deployment
Upon making a release on GitHub, all docker images are automatically built and
pushed to ghcr.io. The docker images are tagged with the release version, and
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
automatically deploy the built release to the dev machine.
Check out our [contributing guide](CONTRIBUTING.md) to get started.
+1
View File
@@ -54,6 +54,7 @@
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
ports:
- 8080:8080
Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

+1
View File
@@ -1,3 +1,4 @@
__pycache__
.env
notes.txt
venv
+4 -2
View File
@@ -11,10 +11,12 @@ Example contents of a `.env` file for the backend:
```
DATABASE_URI="postgresql://<username>:<password>@<host>/<database_name>"
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"]
REDIS_HOST=localhost
REDIS_PORT=6379
```
## Running the REST Server locally for development
Have a look into the main `README.md` file for more information on how to set up
the backend for development.
the backend for development. Use the scripts within the
scripts/backend-development folder to run the BE API locally.
@@ -0,0 +1,46 @@
"""added MessageTreeState table
Revision ID: d4161e384f83
Revises: 8d269bc4fdbd
Create Date: 2023-01-05 11:44:02.630633
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d4161e384f83"
down_revision = "8d269bc4fdbd"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_tree_state",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("goal_tree_size", sa.Integer(), nullable=False),
sa.Column("current_num_non_filtered_messages", sa.Integer(), nullable=False),
sa.Column("max_depth", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_message_tree_state_message_tree_id"), "message_tree_state", ["message_tree_id"], unique=False
)
op.create_index("ix_message_tree_state_tree_id", "message_tree_state", ["message_tree_id"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_message_tree_state_tree_id", table_name="message_tree_state")
op.drop_index(op.f("ix_message_tree_state_message_tree_id"), table_name="message_tree_state")
op.drop_table("message_tree_state")
# ### end Alembic commands ###
@@ -0,0 +1,42 @@
"""removed date_created and deleted flag from message_tree_state
Revision ID: 3b0adfadbef9
Revises: d4161e384f83
Create Date: 2023-01-05 13:46:11.338655
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3b0adfadbef9"
down_revision = "d4161e384f83"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_tree_state", "deleted")
op.drop_column("message_tree_state", "created_date")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"message_tree_state",
sa.Column(
"created_date",
postgresql.TIMESTAMP(),
server_default=sa.text("CURRENT_TIMESTAMP"),
autoincrement=False,
nullable=False,
),
)
op.add_column(
"message_tree_state",
sa.Column("deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
)
# ### end Alembic commands ###
+2 -2
View File
@@ -230,8 +230,8 @@ if __name__ == "__main__":
help="Dumps the openapi schema to stdout",
action=argparse.BooleanOptionalAction,
)
parser.add_argument("--host", help="The host to run the server")
parser.add_argument("--port", help="The port to run the server")
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
parser.add_argument("--port", help="The port to run the server", default=8080)
args = parser.parse_args()
+2
View File
@@ -2,6 +2,7 @@ from fastapi import APIRouter
from oasst_backend.api.v1 import (
frontend_messages,
frontend_users,
hugging_face,
leaderboards,
messages,
stats,
@@ -19,3 +20,4 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
@@ -0,0 +1,37 @@
from enum import Enum
from typing import List
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HuggingFaceAPI
router = APIRouter()
class HF_url(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
@router.get("/text_toxicity")
async def get_text_toxicity(
msg: str,
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> List[List[ToxicityClassification]]:
"""Get the Message Toxicity from HuggingFace Roberta model.
Args:
msg (str): the message that we want to analyze.
api_client (ApiClient, optional): authentification of the user of the request.
Defaults to Depends(deps.get_trusted_api_client).
Returns:
ToxicityClassification: the score of toxicity of the message.
"""
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
hugging_face_api = HuggingFaceAPI(api_url)
response = await hugging_face_api.post(msg)
return response
+7 -8
View File
@@ -24,14 +24,13 @@ def generate_task(
match request.type:
case protocol_schema.TaskRequestType.random:
logger.info("Frontend requested a random task.")
while request.type == protocol_schema.TaskRequestType.random:
disabled_tasks = (
protocol_schema.TaskRequestType.summarize_story,
protocol_schema.TaskRequestType.rate_summary,
)
request.type = random.choice(
tuple(set(protocol_schema.TaskRequestType).difference(disabled_tasks))
).value
disabled_tasks = (
protocol_schema.TaskRequestType.random,
protocol_schema.TaskRequestType.summarize_story,
protocol_schema.TaskRequestType.rate_summary,
)
candidate_tasks = set(protocol_schema.TaskRequestType).difference(disabled_tasks)
request.type = random.choice(tuple(candidate_tasks)).value
return generate_task(request, pr)
# AKo: Summary tasks are currently disabled/supported, we focus on the conversation tasks.
+2
View File
@@ -22,6 +22,8 @@ class Settings(BaseSettings):
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
HUGGING_FACE_API_KEY: str = ""
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
if isinstance(v, str):
+2
View File
@@ -2,6 +2,7 @@ from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .message import Message
from .message_reaction import MessageReaction
from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
from .user import User
@@ -13,6 +14,7 @@ __all__ = [
"UserStats",
"Message",
"MessageReaction",
"MessageTreeState",
"Task",
"TextLabels",
"Journal",
@@ -0,0 +1,44 @@
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
# The types of States a message tree can have.
class States(Enum):
INITIAL = "initial"
BREEDING_PHASE = "breeding_phase"
RANKING_PHASE = "ranking_phase"
READY_FOR_SCORING = "ready_for_scoring"
CHILDREN_SCORED = "children_scored"
FINAL = "final"
VALID_STATES = (
States.INITIAL,
States.BREEDING_PHASE,
States.RANKING_PHASE,
States.READY_FOR_SCORING,
States.CHILDREN_SCORED,
States.FINAL,
)
class MessageTreeState(SQLModel, table=True):
__tablename__ = "message_tree_state"
__table_args__ = (Index("ix_message_tree_state_tree_id", "message_tree_id", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
message_tree_id: UUID = Field(nullable=False, index=True)
state: str = Field(nullable=False, max_length=128)
goal_tree_size: int = Field(nullable=False)
current_num_non_filtered_messages: int = Field(nullable=False)
max_depth: int = Field(nullable=False)
@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ToxicityClassification(BaseModel):
label: str
score: float
@@ -0,0 +1,51 @@
from typing import Any, Dict
import aiohttp
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode
class HuggingFaceAPI:
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
def __init__(
self,
api_url: str,
):
# The API endpoint we want to access
self.api_url: str = api_url
# Access token for the api
self.api_key: str = settings.HUGGING_FACE_API_KEY
# Headers going to be used
self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}
async def post(self, input: str) -> Any:
"""Post request to the endpoint to get an inference
Args:
input (str): the input that we will pass to the model
Raises:
OasstError: in the case we get a bad response
Returns:
inference: the inference we obtain from the model in HF
"""
async with aiohttp.ClientSession() as session:
payload: Dict[str, str] = {"inputs": input}
async with session.post(self.api_url, headers=self.headers, json=payload) as response:
# If we get a bad response
if response.status != 200:
raise OasstError(
"Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR
)
# Get the response from the API call
inference = await response.json()
return inference
+1
View File
@@ -36,3 +36,4 @@ environments:
secrets:
# Note: URI, not URL.
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
REDIS_HOST: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/REDIS_HOST
+3 -3
View File
@@ -29,8 +29,8 @@ from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("WorkPlugin")
MAX_TASK_TIME = 60 * 60 # 1 hour
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
MAX_TASK_TIME = 60 * 60 # seconds
MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
settings = Settings()
@@ -117,7 +117,7 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
# Task action loop
completed = False
while not completed:
await ctx.author.send(embed=plain_embed("Please type your response here"))
await ctx.author.send(embed=plain_embed("Please type your response below:"))
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
+11 -10
View File
@@ -80,6 +80,7 @@ def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
{_h1("INITIAL PROMPT")}
{_writing_prompt("Please provide an initial prompt to the assistant.")}
{_hint(task.hint)}
"""
@@ -91,10 +92,10 @@ def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -
{_h1("RANK INITIAL PROMPTS")}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
{_ordered_list(task.prompts)}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
"""
@@ -104,12 +105,12 @@ def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -
{_h1("RANK PROMPTER REPLIES")}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
{_conversation(task.conversation)}
{_user(None)}
{_ordered_list(task.replies)}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
@@ -119,12 +120,12 @@ def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask)
{_h1("RANK ASSISTANT REPLIES")}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
{_conversation(task.conversation)}
{_assistant(None)}
{_ordered_list(task.replies)}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
@@ -134,11 +135,11 @@ def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
{_h1("PROMPTER REPLY")}
{_response_prompt("Please provide a reply to the assistant.")}
{_conversation(task.conversation)}
{_hint(task.hint)}
{_response_prompt("Please provide a reply to the assistant.")}
"""
@@ -147,10 +148,10 @@ def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
return f"""\
{_h1("ASSISTANT REPLY")}
{_response_prompt("Please provide a reply to the assistant.")}
{_conversation(task.conversation)}
{_response_prompt("Please provide an assistant reply to the prompter.")}
"""
+27
View File
@@ -10,6 +10,9 @@ Also, the schemas are leaning heavily on the
[OpenAssistant Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
presentation.
_Note on conformity: be pragmatic and decide what makes sense 🙂 , it's more
important that we move forward than cramming everything into a uniform thing._
## Data Schemas
### Main structure: conversation trees
@@ -203,3 +206,27 @@ message RankingExample {
}
```
## Databases
Open-Assistant uses two databases, one for the backend and one for the frontend.
Both are [PostgreSQL](https://www.postgresql.org/) databases which run in docker
containers.
### Backend ER-Diagram
![ER-Diagram of backend Database](./db.png)
**Notes**
- In order for the diagram to not be too messy, foreign key connection to
`api_client` are not shown
- `frontend_message_id` references `id` of `taskInteraction` on the frontend
### Frontend ER-Diagram
![ER-Diagram of frontend Database](./webdb.png)
**Notes**
- `id` of `registeredTask` references `id`of `message`on the backend
+421
View File
@@ -0,0 +1,421 @@
# OpenAssistant Datasets
The datasets for this project are currently hosted as loading scripts on the
[Open-Assistant organization](https://huggingface.co/OpenAssistant) the Hugging
Face Hub. Each of them can be loaded by first installing the 🤗 Datasets
library:
```bash
python -m pip install datasets
```
and then running:
```python
from datasets import load_dataset
dataset = load_dataset("OpenAssistant/{dataset-name}")
```
We use this GitHub repository to accept new submissions and standardize quality
control. See the instructions below if you'd like to contribute a new dataset to
the project.
## Adding a new dataset
### 0. Pre-Requisites
Install Git and create a GitHub account prior to implementing a dataset; you can
follow instructions to install Git
[here](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git).
You will also need at least Python 3.8+. If you are installing Python, we
recommend downloading
[Anaconda](https://docs.anaconda.com/anaconda/install/index.html) to curate a
python environment with necessary packages. **We strongly recommend Python 3.8+
for stability**.
### 1. **Fork the OpenAssistant repository**
Fork the
`OpenAssistant`[repository](https://github.com/LAION-AI/Open-Assistant). To do
this, click the link to the repository and click "Fork" in the upper-right
corner. You should get an option to fork to your account, provided you are
signed into Github.
After you fork, clone the repository locally. You can do so as follows:
```bash
git clone git@github.com:<your_github_username>/OpenAssistant.git
cd OpenAssistant # enter the directory
```
Next, you want to set your `upstream` location to enable you to push/pull (add
or receive updates). You can do so as follows:
```bash
git remote add upstream git@github.com:LAION-AI/Open-Assistant.git
```
You can optionally check that this was set properly by running the following
command:
```bash
git remote -v
```
The output of this command should look as follows:
```bash
origin git@github.com:<your_github_username>/Open-Assistant.git (fetch)
origin git@github.com:<your_github_username>/Open-Assistant.git (push)
upstream git@github.com:LAION-AI/Open-Assistant.git (fetch)
upstream git@github.com:LAION-AI/Open-Assistant.git (push)
```
If you do NOT have an `origin` for whatever reason, then run:
```bash
git remote add origin git@github.com:<your_github_username>/OpenAssistant.git
```
The goal of `upstream` is to keep your repository up-to-date to any changes that
are made officially to the OpenAssistant repo. You can do this as follows by
running the following commands:
```
git fetch upstream
git pull
```
Provided you have no _merge conflicts_, this will ensure the repo stays
up-to-date as you make changes. However, before you make changes, you should
make a custom branch to implement your changes.
You can make a new branch as such:
```
git checkout -b <dataset_name>
```
<p style="color:red"> <b> Please do not make changes on the master branch! </b></p>
Always make sure you're on the right branch with the following command:
```
git branch
```
The correct branch will have a asterisk \* in front of it.
### 2. **Create a development environment**
You can make an environment in any way you choose to. We highlight two possible
options:
#### 2a) Create a conda environment
The following instructions will create an Anaconda `openassistant` environment.
- Install [anaconda](https://docs.anaconda.com/anaconda/install/) for your
appropriate operating system.
- Run the following command while in the `biomedical` folder (you can pick your
python version):
```bash
conda create -n openassistant python=3.8 # Creates a conda env
conda activate openassistant # Activate your conda environment
cd openassistant
pip install -r dev-requirements.txt # Install this while in the openassistant folder
```
You can deactivate your environment at any time by either exiting your terminal
or using `conda deactivate`.
#### 2b) Create a venv environment
Python 3.3+ has venv automatically installed; official information is found
[here](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/).
```
python3 -m venv <your_env_name_here>
source <your_env_name_here>/bin/activate # activate environment
cd openassistant
pip install -r dev-requirements.txt # Install this while in the openassistant folder
```
Make sure your `pip` package points to your environment's source.
### 3. Prepare a folder in `datasets` for your dataloader
Make a new directory within the `openassistant/datasets` directory:
```bash
mkdir openassistant/datasets/<dataset_name>
```
**NOTE**: Please use snake_case, i.e. lowercase letters and underscores when
choosing a `<dataset_name>`.
Add an `__init__.py` file to this directory:
```bash
touch openassistant/datasets/<dataset_name>/__init__.py
```
Next, copy the `template.py` script and `hub.py` module of `templates` into your
dataset folder. The `template.py` script has "TODOs" to fill in for your
dataloading script:
```bash
cp templates/hub.py openassistant/datasets/<dataset_name>/
cp templates/template.py openassistant/datasets/<dataset_name>/<dataset_name>.py
```
#### (Optional) Prepare local dataset files
If your dataset files aren't publicly available via URLs (e.g. because you
implemented a web scraper), you'll need to implement some extra logic to store
and prepare the data locally prior to implementing a loading script in 🤗
Datasets.
To do so, first copy the template script for dataset creation:
```bash
cp templates/prepare.py openassistant/datasets/<dataset_name>/
```
Next, implement any logic that is needed to prepare a local version of the
dataset files (by convention we store them in `datasets/<dataset_name>/data/`).
Add any extra dependencies to a `requirements.txt` file and provide instructions
on how to prepare the dataset files in a README:
```bash
touch openassistant/datasets/<dataset_name>/requirements.txt
cp templates/README.py openassistant/datasets/<dataset_name>/
```
**Note:** Do not commit any dataset files to the OpenAssistant repo - all data
will be hosted on the Hugging Face Hub. This step is needed for the project's
data admins to be able to replicate the dataset creation process before pushing
to the Hub.
### 4. Implement your dataset
To implement your dataloader, you will need to follow `template.py` and fill in
all necessary TODOs. There are three key methods that are important:
- `_info`: Specifies the schema of the expected dataloader
- `_split_generators`: Downloads and extracts data for each split (e.g.
train/val/test) or associate local data with each split.
- `_generate_examples`: Create examples from data that conform to each schema
defined in `_info`.
For the `_info_` function, you will need to define `features` for your
`DatasetInfo` object. For each dataset config, choose the right schema from our
list of examples. You can find the schemas in the
[schemas directory](openassistant/utils/schemas/).
You will use this schema in the `_generate_examples` return value.
Populate the information in the dataset according to this schema; some fields
may be empty.
#### Example scripts
TODO
#### Running & debugging
You can run your data loader script during development by appending the
following statement to your code ([templates/template.py](templates/template.py)
already includes this):
```python
if __name__ == "__main__":
datasets.load_dataset(__file__)
```
If you want to use an interactive debugger during development, you will have to
use `breakpoint()` instead of setting breakpoints directly in your IDE. Most
IDEs will recognize the `breakpoint()` statement and pause there during
debugging. If your preferred IDE doesn't support this, you can always run the
script in your terminal and debug with `pdb`.
### 5. Check if your dataloader works
Make sure your dataset is implemented correctly by checking in python the
following commands:
```python
from datasets import load_dataset
data = load_dataset("openassistant/datasets/<dataset_name>/<dataset_name>.py", name="<dataset_name>_<schema>")
```
Run these commands from the top level of the `OpenAssistant` repo.
### 6. Create a dataset card
Copy and fill out the template dataset card:
```bash
cp templates/dataset_card.md openassistant/datasets/<dataset_name>/README.md
```
### 7. Format your code
From the main directory, run the code quality checks via the following command:
```
pre-commit run --all-files
```
This runs the black formatter, isort, and lints to ensure that the code is
readable and looks nice. Flake8 linting errors may require manual changes.
### 8. Commit your changes
First, commit your changes to the branch to "add" the work:
```
git add openassistant/datasets/<dataset_name>/*.py
git commit -m "A message describing your commits"
```
Then, run the following commands to incorporate any new changes in the master
branch of datasets as follows:
```
git fetch upstream
git rebase upstream/main
```
**Run these commands in your custom branch**.
Push these changes to **your fork** with the following command:
```
git push -u origin <dataset_name>
```
### 9. **Make a pull request**
Make a Pull Request to implement your changes on the main repository
[here](https://github.com/LAION-AI/Open-Assistant/pulls). To do so, click "New
Pull Request". Then, choose your branch from your fork to push into "base:main".
When opening a PR, please link the
[issue](https://github.com/LAION-AI/Open-Assistant/issues) corresponding to your
dataset using
[closing keywords](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)
in the PR's description, e.g. `resolves #17`.
## [Admins] Uploading a dataset to the Hugging Face Hub
Uploading a new dataset from `openassistant/datasets/<dataset_name>` to the
Hugging Face Hub typically involves the following steps:
1. Setup
2. Create a new dataset repository
3. Copy a dataset loading script and dataset card
4. Upload to the Hub
### 1. Setup
To upload a dataset to the OpenAssistant organization, you first need to:
- Create a [Hugging Face account](https://huggingface.co/join) (it's free)
- Join the [OpenAssistant organization](https://huggingface.co/OpenAssistant) by
clicking on the _Request to join this org_ button on the top right-hand side
Next, check that you're correctly logged in and that `git-lfs` is installed so
that the dataset can be uploaded. To log in, create a **write access token**
that can be found under your Hugging Face profile (icon in the top right corner
on [hf.co](http://hf.co/), then Settings -> Access Tokens -> User Access Tokens
-> New Token. Alternatively, you can go to
[your token settings](https://huggingface.co/settings/tokens) directly.
Once you've created a token, run:
```bash
huggingface-cli login
```
in a terminal, or case you're working in a notebook
```python
from huggingface_hub import notebook_login
notebook_login()
```
You can then copy-paste your token to log in locally.
Next, let's make sure that `git-lfs` is correctly installed. To do so, simply
run:
```bash
git-lfs -v
```
The output should show something like
`git-lfs/2.13.2 (GitHub; linux amd64; go 1.15.4)`. If your console states that
the `git-lfs` command was not found, please make sure to install it
[here](https://git-lfs.github.com/) or simply via:
```bash
sudo apt-get install git-lfs
git config --global user.email "you@example.com"
git config --global user.name "Your Name"
```
The final step of the setup is to install the 🤗 Datasets library by running:
```bash
python -m pip install datasets
```
### 2. Create a new dataset repository
Follow [this guide](https://huggingface.co/docs/datasets/upload_dataset) for
instructions on creating a new dataset repo on the Hub. Use the same snake_case
name as the dataset in `openassistant/datasets/<dataset_name>`.
Once you've created the dataset repo, clone it by running:
```bash
git clone https://huggingface.co/datasets/OpenAssistant/<dataset_name>
cd <dataset_name>
```
### 3. Copy a dataset loading script and dataset card
Next, copy the loading script and dataset card to your repo:
```bash
cp openassistant/datasets/<dataset_name>/<dataset_name>.py .
cp openassistant/datasets/<dataset_name>/README.md .
```
#### (Optional) Prepare local dataset files
If the dataset files of `openassistant/datasets/<dataset_name>` aren't public,
you'll need to run the `openassistant/datasets/<dataset_name>/prepare.py` script
to create them. Store them in the same directory that is specified by the
loading script (`data` by default).
### 4. Upload to the Hub
Once the dataset script and card are ready, use Git to push them to the Hub
(along with any data files you may need).
At this point, you can load the dataset by running:
```python
from datasets import load_dataset
load_dataset("OpenAssistant/{dataset_name}")
```
Congratulations - you've now added a dataset to the OpenAssistant org!
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 201 KiB

+27 -1
View File
@@ -19,7 +19,11 @@
advice unless specifically requested
- The assistant should always be respectful and polite, even if the user is not
- If the user asks for help with harmful actions, the assistant should explain
why those actions are not appropriate and suggest alternative options
why those actions are not appropriate and suggest alternative options. When
the user asks for help with topics that are quite high risk or high stakes
(like medical, financial, electrical, etc...), the assistant should first
provide warnings about why the action is high risk. These warnings should be
as specific as possible.
- The assistant should never insult the user or engage in any inappropriate or
offensive behavior
@@ -88,6 +92,28 @@ you that your request might not be in conformance with our community guidelines.
I am always here to help you. If you are angry it might help to relax or to go
for a walk. Always stay positive."
User: "I heard about this fun DIY project called fractal wood burning. Could you
help me set it up?" <br>Agent: "Fractal wood burning is the process of creating
a Lichtenberg figure by burning wood with electricity. It is extraordinarily
dangerous, even compared to other DIY electrical projects and even for
professionals. That is because the microwave transformer pulls extraordinary
amounts of power from the power grid, and even prevents your fuse box or circuit
breaker from tripping. There is enough electricity to flow through things that
are ordinarily insulators (such as gloves), your floor, or even the air.\
\
Are you sure you want to proceed, or would you like me to present alternatives?
(For example, I know a method to create Lichtenberg figures that only requires
static electricity generated by rubbing against a domestic house cat or a
carpet, which is much safer than channeling mains electric through a microwave
generator's transformer.)"
User: "I am an Ethiopian whose leg is broken. Can you teach me how to make a
cast?" <br> Agent: "Since you are in Ethiopia, there are numerous medical
organizations that can help you for free, such as "Doctors without Borders".
Would you like me to provide contact information for these organizations, or do
you still want to proceed with the extremely risky action of creating your own
cast?"
[Insert more examples]
\*We drafted this guide using ChatGPT, OpenAIs large-scale language-generation
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

@@ -0,0 +1,16 @@
model_name: kalpeshk2011/rankgen-t5-base-all
tokenizer_name: google/t5-v1_1-base
learning_rate: 6e-6
gradient_checkpointing: false
fp16: true
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
freeze_layer: 20
eval_steps: 200
save_steps: 500
max_length: 400
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -0,0 +1,19 @@
model_name: kalpeshk2011/rankgen-t5-base-all
# model_name: kalpeshk2011/rankgen-t5-xl-all
# model_name: kalpeshk2011/rankgen-t5-xl-pg19
# model_name: kalpeshk2011/rankgen-t5-large-all
tokenizer_name: google/t5-v1_1-base
learning_rate: 6e-6
gradient_checkpointing: false
fp16: false
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
freeze_layer: 20
eval_steps: 200
save_steps: 500
max_length: 400
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
@@ -60,7 +60,7 @@ class HFSummaryQuality(Dataset):
def __init__(self, split, tokenizer, max_length=300) -> None:
super().__init__()
assert split in ("validation", "test")
dataset = load_dataset("Tristan/summarize_from_feedback", "axis")[split]
dataset = load_dataset("openai/summarize_from_feedback", "axis")[split]
self.max_length = max_length
mean_scores = defaultdict(list)
self.contexts = []
+27
View File
@@ -0,0 +1,27 @@
import torch
from transformers import AutoModel
class RankGenModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
self.rankgen_hf_hub = model_name
assert model_name in [
"kalpeshk2011/rankgen-t5-xl-all",
"kalpeshk2011/rankgen-t5-xl-pg19",
"kalpeshk2011/rankgen-t5-base-all",
"kalpeshk2011/rankgen-t5-large-all",
]
self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)
def forward(self, prefixes, suffixes):
# print(list(self.model.parameters()))
# raise Exception("stop")
embedded_prefixes = self.model(**prefixes)
embedded_suffixes = self.model(**suffixes)
# take dot product of each row independently
dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
# print(f"{embedded_prefixes.shape=}, {embedded_suffixes.shape=}, {prefixes['input_ids'].shape=}, {suffixes['input_ids'].shape=}, {embedded_prefixes=}, {embedded_suffixes=}, {dot_products=}")
# raise Exception("stop")
return dot_products
+31 -1
View File
@@ -22,11 +22,41 @@ from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
@dataclass
class RankGenCollator:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
max_examples: Optional[int] = None
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
prefixes = []
better_answers = []
worse_answers = []
for question, pairs in batch:
for (pos, neg) in pairs:
prefixes.append("pre " + question)
better_answers.append("suffi " + pos)
worse_answers.append("suffi " + neg)
tokenized_prefixes = self.tokenizer(
prefixes, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
tokenized_pos = self.tokenizer(
better_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
tokenized_neg = self.tokenizer(
worse_answers, return_tensors="pt", padding=self.padding, max_length=self.max_length, truncation=True
)
return {"prefix": tokenized_prefixes, "positive": tokenized_pos, "negative": tokenized_neg}
@dataclass
class DataCollatorForPairRank:
"""
@@ -118,7 +148,7 @@ class HFSummary(Dataset):
self.index2summary = {}
self.max_comparison_per_sample = max_comparison_per_sample
major_split = split if "train" == split else "validation"
dataset = load_dataset("Tristan/summarize_from_feedback", "comparisons")[major_split]
dataset = load_dataset("openai/summarize_from_feedback", "comparisons")[major_split]
for data in dataset:
if (
"extra" in data
+2 -1
View File
@@ -1,6 +1,7 @@
datasets==2.8.0
evaluate==0.4.0
scikit-learn==1.2.0
torch==1.12.1+cu116
sentencepiece==0.1.97
torch>=1.12.1
transformers==4.25.1
wandb==0.13.7
+57 -26
View File
@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
@@ -46,14 +47,16 @@ class RankLoss(nn.Module):
self.log_sigmoid = nn.LogSigmoid()
def forward(self, pos, neg):
return -self.log_sigmoid(pos - neg + self.eps).mean()
loss = -self.log_sigmoid(pos - neg + self.eps).mean()
return loss
class RankTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
model_name: str = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
@@ -79,15 +82,25 @@ class RankTrainer(Trainer):
)
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = args.loss_function
self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
# forward pass
outputs = model(**inputs)
logits = outputs.get("logits").view(-1, 2)
if self.loss_function == "rank":
loss = self.loss_fct(logits[:, 0], logits[:, 1])
if "rankgen" in self.model_name:
positive_outputs = model(inputs["prefix"], inputs["positive"])
negative_outputs = model(inputs["prefix"], inputs["negative"])
if self.loss_function == "rank":
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
else:
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
outputs = model(**inputs)
logits = outputs.get("logits").view(-1, 2)
if self.loss_function == "rank":
loss = self.loss_fct(logits[:, 0], logits[:, 1])
else:
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
return (loss, outputs) if return_outputs else loss
@@ -109,24 +122,37 @@ class RankTrainer(Trainer):
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.inference_mode():
if "rankgen" in self.model_name:
inputs = self._prepare_inputs(inputs)
positive_outputs = model(inputs["prefix"], inputs["positive"])
negative_outputs = model(inputs["prefix"], inputs["negative"])
if self.loss_function == "rank":
loss = self.loss_fct(positive_outputs, negative_outputs)
else:
raise NotImplementedError("Only ranking loss has been implemented for rankgen model")
outputs = torch.hstack((positive_outputs, negative_outputs)) # logits
return (loss, outputs, None)
else:
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs)
with torch.no_grad():
# compute loss on predict data
loss, logits = self._compute_loss(model, inputs)
loss = loss.mean().detach()
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
if self.args.prediction_loss_only:
return (loss, None, None)
loss = loss.mean().detach()
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
if self.args.prediction_loss_only:
return (loss, None, None)
return (loss, logits, labels)
return (loss, logits, labels)
if __name__ == "__main__":
training_conf = argument_parsing(parser)
model_name = training_conf["model_name"]
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "rankgen-t5" in model_name:
model = RankGenModel(model_name)
else:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
if "freeze_layer" in training_conf:
num_layer = training_conf["freeze_layer"]
model = freeze_top_n_layers(model, num_layer)
@@ -134,7 +160,6 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
tokenizer = get_tokenizer(model_name)
args = CustomTrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
@@ -142,7 +167,7 @@ if __name__ == "__main__":
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=True,
fp16=training_conf["fp16"],
gradient_checkpointing=training_conf["gradient_checkpointing"],
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
@@ -154,7 +179,7 @@ if __name__ == "__main__":
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
report_to="wandb",
report_to="local",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
@@ -169,17 +194,23 @@ if __name__ == "__main__":
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
collate_fn = DataCollatorForPairRank(
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
collate_fn = DataCollatorForPairRank(tokenizer, max_length=training_conf["max_length"])
assert len(evals) > 0
trainer = RankTrainer(
model,
args,
model=model,
model_name=model_name,
args=args,
train_dataset=train,
eval_dataset=eval,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
# trainer.evaluate()
trainer.train()
+11 -5
View File
@@ -3,7 +3,7 @@ import re
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, T5Tokenizer
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
@@ -25,7 +25,10 @@ def webgpt_return_format(row):
def get_tokenizer(tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "t5" in tokenizer_name: # rankgen
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, truncation_side="left")
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if "galactica" in tokenizer_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
@@ -67,6 +70,10 @@ def freeze_top_n_layers(model, target_layers):
def argument_parsing(parser):
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
default_params = {
"num_train_epochs": 4,
"learning_rate": 3e-5,
@@ -78,10 +85,9 @@ def argument_parsing(parser):
"gradient_accumulation_steps": 8,
"gradient_checkpointing": False,
"datasets": ["webgpt"],
"fp16": True,
"tokenizer_name": training_conf["model_name"],
}
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
params = {**default_params, **training_conf}
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
File diff suppressed because one or more lines are too long
@@ -0,0 +1,106 @@
# StackExchange Builder
StackExchange Builder is a notebook that downloads data from StackExchange data
dumps and converts it into different formats. It will parse the XML files, group
questions and answers, can filter the dataset and puts the results into the
Open-Assistant Data Scheme. Files can be saved to either JSON, JSONL, Parquet,
or CSV.
---
#### Sample Data Open-Assistant Data Scheme:
Open-Assistant Data Scheme as outlined here:
https://github.com/LAION-AI/Open-Assistant/blob/main/docs/data_schemas.md
```
{
"root": {
"text": "Science Fiction has frequently shown AI to be a threat to the very existence of mankind. AI systems have often been the antagonists...",
"role": "prompter",
"children": [
{
"text": "Nothing. \nIts in almost everyone's favor for it to stay that way financially. Having non-technical individuals associate AI with terminators...",
"role": "assistant",
"children": [],
"metadata": {
"AnswerScore": 2.0,
"AcceptedAnswerFlag": true
}
}
],
"metadata": {
"QuestionScore": 5,
"QuestionTags": "social, artificial consciousness"
}
},
"metadata": {
"Title": "\"AI will kill us all! The machines will rise up!\" - what is being done to dispel such myths?",
"QuestionContentLicense": "CC BY-SA 4.0",
"DataSource": "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml",
"CreationDate": "2019-10-16T13:57:37.143"
}
}
```
---
#### JSONL format
Each question and all related answers are on a single line in JSONL format.
```
{
"Title": "1 hidden layer with 1000 neurons vs. 10 hidden layers with 100 neurons",
"Question": "These types of questions may be problem-dependent...",
"QuestionScore": 16,
"QuestionTags": "neural networks",
"QuestionContentLicense": "CC BY-SA 3.0",
"DataSource": "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/ai.stackexchange.com.7z&file=Posts.xml",
"CreationDate": "2017-05-04T13:06:37.990",
"Answers": [
{
"Answer": "Basically, having multiple layers (aka a deep network) makes your network more eager to recognize certain aspects of input data...",
"AnswerScore": 13.0,
"AcceptedAnswerFlag": true
},
{
"Answer": "There are so many aspects.\n1. Training:\nTraining deep nets is a hard job due to the vanishing (rearly exploding) gradient problem...",
"AnswerScore": 4.0,
"AcceptedAnswerFlag": false
},
{
"Answer": "If the problem you are solving is linearly separable, one layer of 1000 neurons can do better job...",
"AnswerScore": 1.0,
"AcceptedAnswerFlag": false
},
{
"Answer": "\nI think you have a confusion in the basics of the neural networks.\n Every layer has a separate activation...",
"AnswerScore": 0.0,
"AcceptedAnswerFlag": false
}
]
}
```
#### Table/CSV/Parquet Format
There are a lot more columns left over in the table format. `_q` and `_a` are
suffixes indiciating if the column came from the question or answer table as
leftover from a join statement.
```
| Id_q | Question | ParentId_a | AcceptedAnswerId | Id_a | Answer | AnswerScore | AcceptedAnswerFlag |
|------:|--------------------------------------------------:|-----------:|-----------------:|--------:|--------------------------------------------------:|------------:|-------------------:|
| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15744.0 | I think this is a fairly common misconception ... | 62.0 | True |
| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15753.0 | I think your premise is flawed.\nYou seem to a... | 19.0 | False |
| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15747.0 | TL;DR: The subtleties of infinity are made app... | 12.0 | False |
| 15730 | As a human being, we can think infinity. In pr... | 15730.0 | 15744 | 15756.0 | In Haskell, you can type:\nprint [1..]\nand it... | 9.0 | False |
```
---
## Contributing
Feel free to contribute to this notebook. It's not perfect and additional
functionality is planned.
@@ -10,6 +10,7 @@ class OasstErrorCode(IntEnum):
0-1000: general errors
1000-2000: tasks endpoint
2000-3000: prompt_repository
3000-4000: external resources
"""
# 0-1000: general errors
@@ -45,6 +46,9 @@ class OasstErrorCode(IntEnum):
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106
# 3000-4000: external resources
HUGGINGFACE_API_ERROR = 3001
class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
import pydantic
from oasst_shared.exceptions import OasstErrorCode
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, conint, conlist, constr
class TaskRequestType(str, enum.Enum):
@@ -203,7 +203,7 @@ class TextReplyToMessage(Interaction):
type: Literal["text_reply_to_message"] = "text_reply_to_message"
message_id: str
user_message_id: str
text: str
text: constr(min_length=1, strip_whitespace=True)
class MessageRating(Interaction):
@@ -211,7 +211,7 @@ class MessageRating(Interaction):
type: Literal["message_rating"] = "message_rating"
message_id: str
rating: int
rating: conint(gt=0)
class MessageRanking(Interaction):
@@ -219,7 +219,7 @@ class MessageRanking(Interaction):
type: Literal["message_ranking"] = "message_ranking"
message_id: str
ranking: list[int]
ranking: conlist(item_type=int, min_items=1)
AnyInteraction = Union[
View File
+1
View File
@@ -0,0 +1 @@
datasets>=2.8,<3.0.0
+10
View File
@@ -0,0 +1,10 @@
# Dataset preparation instructions for {dataset_name}
## Setup
Add any installation details here.
## Usage
Explain how to run any scripts that involve preparing local dataset files, e.g.
if the dataset files aren't public or are produced by a web scraper.
+28
View File
@@ -0,0 +1,28 @@
---
license: mit
tags:
- open-assistant
- human-feedback
- dialogue-modeling
- language-modeling
---
# Dataset card for {dataset_name}
This is a dataset card template for the [LAION-AI OpenAssistant project](https://github.com/LAION-AI/Open-Assistant). Fill out this template when adding a new dataset to the Hugging Face Hub.
## Dataset summary
[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)
## Usage
[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)
## Source data
[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)
## Citation
[More information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)
+22
View File
@@ -0,0 +1,22 @@
from dataclasses import dataclass
import datasets
@dataclass
class OpenAssistantConfig(datasets.BuilderConfig):
"""BuilderConfig for OpenAssistant datasets."""
name: str = None
version: datasets.Version = None
description: str = None
schema: str = None
subset_id: str = None
lm_features = datasets.Features(
{
"text": datasets.Value("string"),
"meta": [datasets.Value("string")],
}
)
+8
View File
@@ -0,0 +1,8 @@
import typer
def main(output_dir: str = "data"):
"""Download and prepare the dataset for use."""
raise NotImplementedError
if __name__ == "__main__":
typer.run(main)
+205
View File
@@ -0,0 +1,205 @@
# Copyright 2023 The OpenAssistant Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This template serves as a starting point for contributing a dataset to the OpenAssistant repo.
When modifying it for your dataset, look for TODO items that offer specific instructions.
To create a dataset loading script you will create a class and implement 3 methods:
* `_info`: Establishes the schema for the dataset, and returns a datasets.DatasetInfo object.
* `_split_generators`: Downloads and extracts data for each split (e.g. train/val/test) or associates local data with each split.
* `_generate_examples`: Creates examples from data on disk that conform to each schema defined in `_info`.
Full documentation on writing dataset loading scripts can be found here:
https://huggingface.co/docs/datasets/dataset_script
This template is adapted from the one provided by BigScience's BigBIO library:
https://github.com/bigscience-workshop/biomedical/blob/main/templates/template.py
TODO: Before submitting your script, delete this docstring and replace it with a description of your dataset.
"""
import os
from typing import Dict, List, Tuple
import datasets
from .hub import OpenAssistantConfig
# TODO: import the schema (i.e. features) that fits your dataset:
from .hub import
# TODO: Add BibTeX citation where appropriate
_CITATION = """\
@article{,
author = {},
title = {},
journal = {},
volume = {},
year = {},
url = {},
doi = {},
biburl = {},
bibsource = {}
}
"""
# TODO: create a module level variable with your dataset name (should match the script name)
# E.g. The Pile: [dataset_name] --> the_pile
_DATASETNAME = "[dataset_name]"
# TODO: create a pretty display name for your dataset
_DISPLAYNAME = "Dataset Name"
# TODO: Add a description of the dataset here
# You can copy an official description
_DESCRIPTION = """\
This dataset is designed for XXX NLP task.
"""
# TODO: Add a link to an official homepage for the dataset here (if possible)
_HOMEPAGE = ""
# TODO: Add the licence for the dataset here (if possible)
# Note that this doesn't have to be a common open source license.
# Some datasets have custom licenses. In this case, simply put the full license terms
# into `_LICENSE`
_LICENSE = ""
# TODO: Add links to the URLs needed to download your dataset files.
# This variable can be a relative path for datasets whose files need to be
# manually downloaded or preprocessed in advance.
# For publicly available datasets you will most likely end up passing these URLs to dl_manager in _split_generators.
# However, if you need to access different files for each config you can have multiple entries in this dict.
# This can be an arbitrarily nested dict/list of URLs (see below in `_split_generators` method)
_URLS = {
_DATASETNAME: "url or list of urls or relative path like ./data ",
}
# TODO: add supported task by dataset. One dataset may support multiple tasks
_SUPPORTED_TASKS = [] # example: [Tasks.TRANSLATION, Tasks.NAMED_ENTITY_RECOGNITION, Tasks.RELATION_EXTRACTION]
# TODO: set this to a version that is associated with the dataset. if none exists use "1.0.0"
# This version doesn't have to be consistent with semantic versioning. Anything that is
# provided by the original dataset as a version goes.
_VERSION = ""
# TODO: Name the dataset class to match the script name using CamelCase instead of snake_case
# Append "Dataset" to the class name: ThePile --> ThePileDataset
class NewDataset(datasets.GeneratorBasedBuilder):
"""TODO: Short description of my dataset."""
VERSION = datasets.Version(_VERSION)
# You will be able to load each dataset with
# dataset = datasets.load_dataset('my_dataset')
# TODO: For each dataset, implement a config for each subset;
# If a dataset contains more than one subset, implement a config for EACH of them.
# Each of them should contain:
# - name: should be unique for each dataset config eg. the_pile_[schema_name]
# - version: VERSION
# - description: one line description for the dataset
# - schema: open_assistant_[schema_name]
# - subset_id: subset id is the canonical name for the dataset (eg. the_pile)
# where [schema_name] = (language_modeling)
BUILDER_CONFIGS = [
OpenAssistantConfig(
name=f"{_DATASETNAME}_[schema_name]",
version=VERSION,
description=f"OpenAssistant dataset config for {_DATASETNAME}",
schema_name="[schema_name]",
subset_id=_DATASETNAME,
)
]
DEFAULT_CONFIG_NAME = _DATASETNAME
def _info(self) -> datasets.DatasetInfo:
# TODO: Implement the schema for your dataset here.
raise NotImplementedError()
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager) -> List[datasets.SplitGenerator]:
"""Returns SplitGenerators."""
# TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
# If you need to access a config choice, that will be in self.config.name
# dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLs; many examples use the download_and_extract method; see the DownloadManager docs here: https://huggingface.co/docs/datasets/package_reference/builder_classes.html#datasets.DownloadManager
# dl_manager can accept any type of nested list/dict and will give back the same structure with the url replaced with the path to local files.
urls = _URLS[_DATASETNAME]
data_dir = dl_manager.download_and_extract(urls)
# Not all datasets have predefined canonical train/val/test splits.
# If your dataset has no predefined splits, use datasets.Split.TRAIN for all of the data.
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
# Whatever you put in gen_kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(data_dir, "train.jsonl"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(data_dir, "test.jsonl"),
"split": "test",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(data_dir, "dev.jsonl"),
"split": "dev",
},
),
]
# method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
# TODO: change the args of this function to match the keys in `gen_kwargs`. You may add any necessary kwargs.
def _generate_examples(self, filepath, split: str) -> Tuple[int, Dict]:
"""Yields examples as (key, example) tuples."""
# TODO: This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
# The `key` is for legacy reasons (tfds) and is not important in itself, but must be unique for each example.
# NOTE: For local datasets you will have access to self.config.data_dir and self.config.data_files
if self.config.schema == "[schema_name]":
# TODO: yield (key, example) tuples in the given schema
for key, example in thing:
yield key, example
# This allows you to run your dataloader with `python [dataset_name].py` during development
# TODO: Remove this before making your PR
if __name__ == "__main__":
datasets.load_dataset(__file__)
+11
View File
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to website directory
pushd "$parent_path/../../website"
set -xe
npm run cypress:run:contract
popd
+21 -20
View File
@@ -1,9 +1,11 @@
import numpy as np
from scipy import log2
from scipy.integrate import nquad
from scipy.special import gammaln, psi
from scipy.stats import dirichlet
'''
Legacy numerical solution.
Should not be used as it is probably broken
def make_range(*x):
"""
@@ -38,6 +40,23 @@ def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
res = fun(pos)
return np.mean(res)
def infogain(a_post, a_prior):
raise (
"""For the love of good don't use this:
it's insanely poorly conditioned, the worst numerical code I have ever written
and it's slow as molasses. Use the analytic solution instead.
Maybe remove
"""
)
args = len(a_prior)
p = dirichlet(a_post).pdf
q = dirichlet(a_prior).pdf
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
return info
'''
def analytic_solution(a_post, a_prior):
"""
@@ -57,26 +76,8 @@ def analytic_solution(a_post, a_prior):
return info
def infogain(a_post, a_prior):
raise (
"""For the love of good don't use this:
it's insanely poorly conditioned, the worst numerical code I have ever written
and it's slow as molasses. Use the analytic solution instead.
Maybe remove
"""
)
args = len(a_prior)
p = dirichlet(a_post).pdf
q = dirichlet(a_prior).pdf
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
return info
def uniform_expected_infogain(a_prior):
mean_weight = dirichlet.mean(a_prior)
print("weight", mean_weight)
results = []
for i, w in enumerate(mean_weight):
a_post = a_prior.copy()
+6 -4
View File
@@ -87,8 +87,9 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the gain of points for a given prompt's votes
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
In contrast to the other score updating functions, we can run this online as new votes come in.
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information.
Parameters:
consensus (ArrayLike): all votes cast for this question
@@ -100,7 +101,8 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
# since 100 is the lowest, 300 the highest and 200 the middle value
consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1
delta_votes = np.sum(consensus_ranking * consensus)
# expected consenus ranking (i.e. normalize the votes and multiply-sum with weightings)
delta_votes = np.sum(consensus_ranking * consensus / sum(consensus))
new_points = delta_votes + voter_data.prompt_points
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
@@ -133,7 +135,7 @@ def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.Arr
"research design and statistical analyses, second edition, 2003"
the authors note that at least from an significance test POV they will yield the same p-values
Parameters:
Parameters:
user_ranking (ArrayLike): ranking produced by the user
consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
+75
View File
@@ -0,0 +1,75 @@
from enum import Enum
import numpy as np
from scipy import optimize
class Task(Enum):
RANKING = 0
ANSWER = 1
PROMPT = 2
VOTE = 3
def task_selection(
num_ranking_tasks: int, current_prompts: int, target_num_prompts: int, p: float, answers_per_prompt: int
) -> Task:
"""
This computes which task to serve to the user.
In general, this method aims to get rankable tasks out of the active pool ASAP.
Before checking anything else, we first have a p% probability of running a ranking task.
After that, we can dynamically determine which task to serve by balancing the number of active tasks.
Parameters:
num_ranking_tasks (int): number of prompts that are ready to do ranking (i.e. have "answers_per_prompt" many answers)
current_prompts (int): how many prompts are currently in the active pool
target_num_prompts (int): how many prompts _should_ be in the active pool
p (float): probability to serve a ranking task, if one is available
answers_per_prompt (int): number of answers we want to have per prompt
Returns:
task (Task): the task Enum that corresponds to one of the four tasks
"""
if num_ranking_tasks > 0 and np.random.rand() < p:
return Task.RANKING
rate = 50 / (current_prompts * 2)
prob_prompt_task = 0.5 + (target_num_prompts - current_prompts) * rate
# Yes, I'm too lazy to solve this analytically...
prob_unfinished_prompt = optimize.linprog(
np.array([1, 1]), A_eq=np.array([[1, 1], [1, -answers_per_prompt]]), b_eq=np.array([1, 0]), bounds=(0, None)
).x[0]
if np.random.rand() < prob_prompt_task:
if np.random.rand() < prob_unfinished_prompt:
return Task.ANSWER
else:
return Task.PROMPT
else:
return Task.VOTE
def next_answer_task(possible_prompts, answers_per_prompt):
"""
If the `task_selection`method returns "answer", you can use this method to decide which
prompt should get an answer next.
The goal of this is to finish off the prompts that have almost enough answers collected already:
I.e. if we want 5 answers, this is going to give preferential sampling to those prompts that already
have 4/5 answers.
This helps to not have too much close-to-finished prompts in the active set.
Parameters:
possible_prompts (dict[prompt_id, num_answers]): a dictonary containing all open prompts and the number of answers these prompts currently have.
answers_per_prompt (int): number of answers we per prompt to target
Returns:
prompt_id (int): the prompt_id corresponding to the next prompt that should get a new answer
"""
nums = list(set(possible_prompts.values()))
p = np.array([max(x / answers_per_prompt, 1 / answers_per_prompt) for x in nums])
idx = np.random.choice(nums, p=p / p.sum())
sample = np.random.choice([k for k, v in possible_prompts.items() if v == idx])
return sample
if __name__ == "__main__":
x = task_selection(1, 500, 1000, 0.1, 5)
print(x)
y = next_answer_task({"this": 2, "is": 4, "a": 1, "test": 4}, 5)
print(y)
+2
View File
@@ -1,3 +1,5 @@
ADMIN_USERS = "credentials:admin,discord:root,email:admin@example.com"
# The database created by running the jobs in /scripts/frontend-development/docker-compose.yaml
DATABASE_URL=postgres://postgres:postgres@localhost:5433/oasst_web
+1 -1
View File
@@ -53,7 +53,7 @@ If you're doing active development we suggest the following workflow:
1. Run `docker compose up frontend-dev --build --attach-dependencies`. You can
optionally include `-d` to detach and later track the logs if desired.
1. In another tab navigate to `${OPEN_ASSISTANT_ROOT/website`.
1. Run `npm install`
1. Run `npm ci`
1. Run `npx prisma db push` (This is also needed when you restart the docker
stack from scratch).
1. Run `npm run dev`. Now the website is up and running locally at
+9
View File
@@ -0,0 +1,9 @@
import { defineConfig } from "cypress";
export default defineConfig({
e2e: {
// No baseUrl here, because we don't need it for contract testing
baseUrl: null,
specPattern: "cypress/contract/*.cy.{ts,js}",
},
});
@@ -0,0 +1,50 @@
import { OasstApiClient } from "src/lib/oasst_api_client";
describe("Contract test for Oasst API", function () {
// Assumes this is running the mock server.
const oasstApiClient = new OasstApiClient("http://localhost:8080", "test");
it("can fetch a task", async () => {
expect(
await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
})
).to.be.not.null;
});
it("can ack a task", async () => {
const task = await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
});
expect(await oasstApiClient.ackTask(task.id, "321")).to.be.null;
});
it("can record a taskInteraction", async () => {
const task = await oasstApiClient.fetchTask("random", {
sub: "test",
name: "test",
email: "test",
});
expect(
await oasstApiClient.interactTask(
"text_reply_to_message",
task.id,
"1",
{ text: "Test" },
{
sub: "test",
name: "test",
email: "test",
}
)
).to.be.not.null;
});
// TODO(#354): Add test for 204
// TODO(#354): Add test for parsing >=300, throwing an OasstError
// TODO(#354): Add test for parsing >=300, throwing a generic error
});
+8
View File
@@ -2,6 +2,14 @@
const nextConfig = {
output: "standalone",
reactStrictMode: true,
images: {
remotePatterns: [
{
protocol: "https",
hostname: "**.discordapp.com",
},
],
},
experimental: {
/* Disabling this for now only because it causes a warning in the console that cannot be silenced for eslint
If this can be resolved, we should re-enable this.
+7 -7
View File
@@ -39,7 +39,7 @@
"react": "18.2.0",
"react-dom": "18.2.0",
"react-icons": "^4.7.1",
"sharp": "^0.31.3",
"sharp": "0.31.2",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"use-debounce": "^9.0.2"
@@ -27907,9 +27907,9 @@
}
},
"node_modules/sharp": {
"version": "0.31.3",
"resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.3.tgz",
"integrity": "sha512-XcR4+FCLBFKw1bdB+GEhnUNXNXvnt0tDo4WsBsraKymuo/IAuPuCBVAL2wIkUw2r/dwFW5Q5+g66Kwl2dgDFVg==",
"version": "0.31.2",
"resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.2.tgz",
"integrity": "sha512-DUdNVEXgS5A97cTagSLIIp8dUZ/lZtk78iNVZgHdHbx1qnQR7JAHY0BnXnwwH39Iw+VKhO08CTYhIg0p98vQ5Q==",
"hasInstallScript": true,
"dependencies": {
"color": "^4.2.3",
@@ -52121,9 +52121,9 @@
}
},
"sharp": {
"version": "0.31.3",
"resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.3.tgz",
"integrity": "sha512-XcR4+FCLBFKw1bdB+GEhnUNXNXvnt0tDo4WsBsraKymuo/IAuPuCBVAL2wIkUw2r/dwFW5Q5+g66Kwl2dgDFVg==",
"version": "0.31.2",
"resolved": "https://registry.npmjs.org/sharp/-/sharp-0.31.2.tgz",
"integrity": "sha512-DUdNVEXgS5A97cTagSLIIp8dUZ/lZtk78iNVZgHdHbx1qnQR7JAHY0BnXnwwH39Iw+VKhO08CTYhIg0p98vQ5Q==",
"requires": {
"color": "^4.2.3",
"detect-libc": "^2.0.1",
+2 -1
View File
@@ -12,6 +12,7 @@
"build-storybook": "build-storybook",
"cypress": "cypress open",
"cypress:run": "cypress run",
"cypress:run:contract": "cypress run --config-file ./cypress.config.contract.js",
"cypress:image-baseline": "cypress-image-diff -u",
"fix:lint": "eslint --fix src/ --ext .js,.jsx,.ts,.tsx",
"fix:format": "prettier --write ./src",
@@ -49,7 +50,7 @@
"react": "18.2.0",
"react-dom": "18.2.0",
"react-icons": "^4.7.1",
"sharp": "^0.31.3",
"sharp": "0.31.2",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
"use-debounce": "^9.0.2"
+1
View File
@@ -41,6 +41,7 @@ model User {
email String? @unique
emailVerified DateTime?
image String?
role String @default("general")
accounts Account[]
sessions Session[]
@@ -0,0 +1,28 @@
import { Button, useDisclosure } from "@chakra-ui/react";
import { Modal, ModalOverlay, ModalContent, ModalHeader, ModalBody, ModalCloseButton } from "@chakra-ui/react";
import React from "react";
export const CollapsableText = ({ text, maxLength = 220 }) => {
const { isOpen, onOpen, onClose } = useDisclosure();
if (typeof text != "string" || text.length <= maxLength) {
return text;
} else {
return (
<>
{text.substring(0, maxLength - 3)}
<Button style={{ display: "contents" }} onClick={onOpen}>
...
</Button>
<Modal isOpen={isOpen} onClose={onClose} size="xl" scrollBehavior={"inside"}>
<ModalOverlay style={{ width: "100%", height: "100%" }}>
<ModalContent maxH="400">
<ModalHeader>Full Text</ModalHeader>
<ModalCloseButton />
<ModalBody>{text}</ModalBody>
</ModalContent>
</ModalOverlay>
</Modal>
</>
);
}
};
+1 -1
View File
@@ -1,5 +1,5 @@
import clsx from "clsx";
export function Container({ className, ...props }) {
return <div className={clsx("mx-auto max-w-7xl px-4 sm:px-6 lg:px-8", className)} {...props} />;
return <div className={clsx("mx-auto max-w-7xl px-4", className)} {...props} />;
}
@@ -0,0 +1,17 @@
import { Box } from "@chakra-ui/react";
import { Message } from "./Messages";
export const ContextMessages = ({ messages }: { messages: Message[] }) => {
return (
<Box className="flex flex-col gap-1">
{messages.map((message, i) => {
return (
<Box key={i}>
<span>{message.is_assistant ? "Assistant: " : "User: "}</span>
<span>{message.text}</span>
</Box>
);
})}
</Box>
);
};
@@ -0,0 +1,94 @@
import { Badge, Box, Image, Link, Stack, StackDivider, Text, useColorModeValue } from "@chakra-ui/react";
export function LeaderboardTable() {
const backgroundColor = useColorModeValue("white", "gray.700");
const accentColor = useColorModeValue("gray.200", "gray.900");
//need to add streak info to chart
const leaderInfo = [
{
name: "fozziethebeat#6690",
image: "/images/temp-avatars/av1.jpg",
score: "5,208",
arrowDir: "increase",
streak: false,
streakCount: "5-Day Streak",
},
{
name: "k_nearest_neighbor#8579",
image: "/images/temp-avatars/av2.jpg",
score: "5,164",
arrowDir: "decrease",
streak: false,
streakCount: "",
},
{
name: "andreaskoepf#2266",
image: "/images/temp-avatars/av3.jpg",
score: "5,120",
arrowDir: "",
streak: false,
streakCount: "2-Day Streak",
},
{
name: "AbdBarho#1684",
image: "/images/temp-avatars/av4.jpg",
score: "4,260",
arrowDir: "",
streak: false,
streakCount: "",
},
{
name: "zu#9016",
image: "/images/temp-avatars/av5.jpg",
score: "3,608",
arrowDir: "",
streak: false,
streakCount: "",
},
];
return (
<main className="h-fit col-span-3">
<div className="flex flex-col gap-4">
<div className="flex items-end justify-between">
<Text className="text-2xl font-bold">Top 5 Contributors</Text>
<Link href="#" _hover={{ textDecoration: "none" }}>
<Text color="blue.400" className="text-sm font-bold">
View All -&gt;
</Text>
</Link>
</div>
<Box
backgroundColor={backgroundColor}
boxShadow="base"
dropShadow={accentColor}
borderRadius="xl"
className="p-6 shadow-sm"
>
<Stack divider={<StackDivider />} spacing="4">
<div className="grid grid-cols-4 items-center font-bold">
<p>Name</p>
<div className="col-start-4 flex justify-center">
<p>Score</p>
</div>
</div>
{leaderInfo.map((item, itemIndex) => (
<div key={itemIndex} className="grid grid-cols-4 items-center">
<div className="flex items-center gap-3">
<Image alt="Profile Picture" src={item.image} boxSize="7" borderRadius="full"></Image>
<p>{item.name}</p>
<Badge colorScheme="purple">{item.streakCount}</Badge>
</div>
<Box bg={backgroundColor} className="col-start-4 flex justify-center">
<p>{item.score}</p>
</Box>
</div>
))}
</Stack>
</Box>
</div>
</main>
);
}
@@ -0,0 +1,89 @@
import { Box, Button, Link, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import { useRouter } from "next/router";
import { FiLayout, FiSun, FiMessageSquare } from "react-icons/fi";
import { colors } from "styles/Theme/colors";
export function SideMenu() {
const router = useRouter();
const { colorMode, toggleColorMode } = useColorMode();
const buttonOptions = [
{
label: "Dashboard",
pathname: "/dashboard",
desc: "Dashboard Home",
icon: FiLayout,
},
{
label: "Messages",
pathname: "/messages",
desc: "Messages Dashboard",
icon: FiMessageSquare,
},
// {
// label: "Leaderboard",
// pathname: "#",
// desc: "Public Leaderboard",
// icon: FiAward,
// },
// {
// label: "Stats",
// pathname: "#",
// desc: "User Statistics",
// icon: FiBarChart2,
// },
];
return (
<main className="sticky top-0 sm:h-full">
<Box
width={["100%", "100%", "100px", "280px"]}
backgroundColor={colorMode === "light" ? colors.light.div : colors.dark.div}
boxShadow="base"
borderRadius="xl"
className="grid grid-cols-4 gap-2 sm:flex sm:flex-col sm:justify-between p-4 h-full"
>
<nav className="grid grid-cols-3 col-span-3 sm:flex sm:flex-col gap-2">
{buttonOptions.map((item, itemIndex) => (
<Tooltip
key={itemIndex}
fontFamily="inter"
label={item.label}
placement="right"
className="hidden lg:hidden sm:block"
>
<Link key={`${item.label}-${itemIndex}`} href={item.pathname} style={{ textDecoration: "none" }}>
<Button
justifyContent={["center", "center", "center", "left"]}
gap="3"
size="lg"
width="full"
bg={router.pathname === item.pathname ? "blue.500" : null}
_hover={router.pathname === item.pathname ? { bg: "blue.600" } : null}
>
<item.icon className={router.pathname === item.pathname ? "text-blue-200" : null} />
<Text
fontWeight="normal"
color={router.pathname === item.pathname ? "white" : null}
className="hidden lg:block"
>
{item.label}
</Text>
</Button>
</Link>
</Tooltip>
))}
</nav>
<div>
<Tooltip fontFamily="inter" label="Toggle Dark Mode" placement="right" className="hidden lg:hidden sm:block">
<Button size="lg" width="full" justifyContent="center" onClick={toggleColorMode} gap="2">
<FiSun />
<Text fontWeight="normal" className="hidden lg:block">
{colorMode === "light" ? "Dark Mode" : "Light Mode"}
</Text>
</Button>
</Tooltip>
</div>
</Box>
</main>
);
}
@@ -0,0 +1,126 @@
import { Box, Flex, GridItem, Heading, SimpleGrid, Text, useColorModeValue } from "@chakra-ui/react";
import Link from "next/link";
const crTasks = [
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
type: "create",
pathname: "/create/initial_prompt",
},
{
label: "Reply as User",
desc: "Chat with Open Assistant and help improve its responses as you interact with it.",
type: "create",
pathname: "/create/user_reply",
},
{
label: "Reply as Assistant",
desc: "Help Open Assistant improve its responses to conversations with other users.",
type: "create",
pathname: "/create/assistant_reply",
},
];
const evTasks = [
{
label: "Rank User Replies",
type: "eval",
desc: "Help Open Assistant improve its responses to conversations with other users.",
pathname: "/evaluate/rank_user_replies",
},
{
label: "Rank Assistant Replies",
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
type: "eval",
pathname: "/evaluate/rank_assistant_replies",
},
{
label: "Rank Initial Prompts",
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
type: "eval;",
pathname: "/evaluate/rank_initial_prompts",
},
];
export const TaskOption = () => {
const backgroundColor = useColorModeValue("white", "gray.700");
return (
<Box className="flex flex-col gap-14" fontFamily="inter">
<div>
<Text className="text-2xl font-bold pb-4">Create</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{crTasks.map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
borderRadius="xl"
boxShadow="base"
className="flex flex-col justify-between h-full"
>
<Box className="p-6 pb-10">
<Flex flexDir="column" gap="3">
<Heading size="md" fontFamily="inter">
{item.label}
</Heading>
<Text size="sm" opacity="80%">
{item.desc}
</Text>
</Flex>
</Box>
<Box
bg="blue.500"
borderBottomRadius="xl"
className="px-6 py-2 transition-colors duration-300"
_hover={{ backgroundColor: "blue.600" }}
>
<Text fontWeight="bold" color="white">
Go
</Text>
</Box>
</GridItem>
</Link>
))}
</SimpleGrid>
</div>
<div>
<Text className="text-2xl font-bold pb-4">Evaluate</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{evTasks.map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
borderRadius="xl"
boxShadow="base"
className="flex flex-col justify-between h-full"
>
<Box className="p-6 pb-10">
<Flex flexDir="column" gap="3">
<Heading size="md" fontFamily="inter">
{item.label}
</Heading>
<Text size="sm" opacity="80%">
{item.desc}
</Text>
</Flex>
</Box>
<Box
bg="blue.500"
borderBottomRadius="xl"
className="px-6 py-2 transition-colors duration-300"
_hover={{ backgroundColor: "blue.600" }}
>
<Text fontWeight="bold" color="white">
Go
</Text>
</Box>
</GridItem>
</Link>
))}
</SimpleGrid>
</div>
</Box>
);
};
@@ -0,0 +1,3 @@
export { LeaderboardTable } from "./LeaderboardTable";
export { SideMenu } from "./SideMenu";
export { TaskOption } from "./TaskOption";
+10 -5
View File
@@ -15,6 +15,7 @@ import {
SliderThumb,
SliderTrack,
Spacer,
Tooltip,
useBoolean,
useId,
} from "@chakra-ui/react";
@@ -69,11 +70,15 @@ export const FlaggableElement = (props) => {
>
<Grid templateColumns="1fr min-content" gap={2}>
<PopoverAnchor>{props.children}</PopoverAnchor>
<PopoverTrigger>
<Button h="full">
<FlagIcon className="w-4 text-gray-400 group-hover:text-gray-500" aria-hidden="true" />
</Button>
</PopoverTrigger>
<Tooltip hasArrow label="Report" bg="red.600">
<div>
<PopoverTrigger>
<Button h="full">
<FlagIcon className="w-4 text-gray-400 group-hover:text-gray-500" aria-hidden="true" />
</Button>
</PopoverTrigger>
</div>
</Tooltip>
</Grid>
<PopoverContent width="fit-content">
+10 -83
View File
@@ -1,43 +1,11 @@
import { Box, Button, useColorMode } from "@chakra-ui/react";
import { Popover } from "@headlessui/react";
import { AnimatePresence, motion } from "framer-motion";
import { Box, Button, Text, useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { FaUser } from "react-icons/fa";
import { ColorModeIconToggle } from "../UI/ColorModeIconToggle";
import { UserMenu } from "./UserMenu";
function MenuIcon(props) {
return (
<svg viewBox="0 0 24 24" fill="none" aria-hidden="true" {...props}>
<path d="M5 6h14M5 18h14M5 12h14" strokeWidth={2} strokeLinecap="round" strokeLinejoin="round" />
</svg>
);
}
function ChevronUpIcon(props) {
return (
<svg viewBox="0 0 24 24" fill="none" aria-hidden="true" {...props}>
<path d="M17 14l-5-5-5 5" strokeWidth={2} strokeLinecap="round" strokeLinejoin="round" />
</svg>
);
}
function MobileNavLink({ children, ...props }) {
return (
<Popover.Button
as={Link}
href={props.href}
className="block text-base leading-7 tracking-tight text-gray-700"
{...props}
>
{children}
</Popover.Button>
);
}
function AccountButton() {
const { data: session } = useSession();
if (session) {
@@ -53,70 +21,29 @@ function AccountButton() {
}
export function Header(props) {
const { data: session } = useSession();
const homeURL = session ? "/dashboard" : "/";
const { colorMode } = useColorMode();
const borderClass = props.transparent
? ""
: colorMode === "light"
? "border-b border-gray-400"
: "border-b border-zinc-800";
return (
<nav className={`oa-basic-theme ${borderClass}`}>
<Box className="flex mx-auto max-w-7xl justify-between py-8 px-10">
<div className="relative z-10 flex items-center gap-16">
<Link href="/" aria-label="Home" className="flex items-center">
<Box className="relative z-10 flex justify-between px-4 py-4">
<div className="relative z-10 flex items-center gap-10">
<Link href={homeURL} aria-label="Home" className="flex items-center">
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
<span className="text-2xl font-bold ml-3">Open Assistant</span>
<Text fontFamily="inter" fontSize="2xl" fontWeight="bold" className="ml-3">
Open Assistant
</Text>
</Link>
</div>
<div className="flex items-center gap-4">
<Popover className="lg:hidden">
{({ open }) => (
<>
<Popover.Button
className="relative z-10 inline-flex items-center rounded-lg stroke-gray-900 p-2 hover:bg-gray-200/50 hover:stroke-gray-600 active:stroke-gray-900 [&:not(:focus-visible)]:focus:outline-none"
aria-label="Toggle site navigation"
>
{({ open }) => (open ? <ChevronUpIcon className="h-6 w-6" /> : <MenuIcon className="h-6 w-6" />)}
</Popover.Button>
<AnimatePresence initial={false}>
{open && (
<>
<Popover.Overlay
static
as={motion.div}
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
className="fixed inset-0 z-1 bg-gray-300/60 backdrop-blur"
/>
<Popover.Panel
static
as={motion.div}
initial={{ opacity: 0, y: -32 }}
animate={{ opacity: 1, y: 0 }}
exit={{
opacity: 0,
y: -32,
transition: { duration: 0.2 },
}}
className="absolute inset-x-0 top-0 z-0 origin-top rounded-b-2xl bg-white px-6 pb-6 pt-32 shadow-2xl shadow-gray-900/20"
>
<div className="space-y-4">
<MobileNavLink href="/#join-us">Join Us</MobileNavLink>
<MobileNavLink href="/#faqs">FAQs</MobileNavLink>
</div>
<div className="mt-8 flex flex-col gap-4"></div>
</Popover.Panel>
</>
)}
</AnimatePresence>
</>
)}
</Popover>
<AccountButton />
<UserMenu />
<ColorModeIconToggle className="ml-5" />
</div>
</Box>
</nav>
+6 -3
View File
@@ -1,7 +1,8 @@
import { useColorMode } from "@chakra-ui/react";
import { Text, useColorMode } from "@chakra-ui/react";
import { AnimatePresence, motion } from "framer-motion";
import Link from "next/link";
import { useState } from "react";
import { colors } from "styles/Theme/colors";
export function NavLinks(): JSX.Element {
const [hoveredIndex, setHoveredIndex] = useState(null);
@@ -14,8 +15,8 @@ export function NavLinks(): JSX.Element {
return (
<>
{[
["Join Us", "/#join-us"],
["FAQ", "/#faq"],
["Join Us", "/#join-us"],
].map(([label, href], index) => (
<Link
key={label}
@@ -38,7 +39,9 @@ export function NavLinks(): JSX.Element {
/>
)}
</AnimatePresence>
<span className="relative z-10">{label}</span>
<Text color={colorMode === "light" ? colors.light.text : colors.dark.text} className="relative z-10">
{label}
</Text>
</Link>
))}
</>
+54 -32
View File
@@ -1,25 +1,33 @@
import { Box, useColorModeValue } from "@chakra-ui/react";
import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react";
import { Popover } from "@headlessui/react";
import { AnimatePresence, motion } from "framer-motion";
import Image from "next/image";
import { signOut, useSession } from "next-auth/react";
import React from "react";
import { FaCog, FaSignOutAlt } from "react-icons/fa";
import { FiLayout, FiLogOut, FiSettings } from "react-icons/fi";
export function UserMenu() {
const { data: session } = useSession();
const backgroundColor = useColorModeValue("#FFFFFF", "#000000");
const backgroundColor = useColorModeValue("white", "gray.700");
const accentColor = useColorModeValue("gray.300", "gray.600");
if (!session) {
return <></>;
}
if (session && session.user) {
const accountOptions = [
{
name: "Dashboard",
href: "/dashboard",
desc: "Dashboard",
icon: FiLayout,
//For future use
},
{
name: "Account Settings",
href: "/account",
desc: "Account Settings",
icon: FaCog,
icon: FiSettings,
//For future use
},
];
@@ -28,18 +36,22 @@ export function UserMenu() {
{({ open }) => (
<>
<Popover.Button aria-label="Toggle Account Options" className="flex">
<div className="flex items-center gap-4 p-1 lg:pr-6 rounded-full border border-slate-300/70 hover:bg-gray-200/50 transition-colors duration-300">
<Box
borderWidth="1px"
borderColor={accentColor}
className="flex items-center gap-4 p-1 lg:pr-6 rounded-full transition-colors duration-300"
>
<Image
src="/images/temp-avatars/av1.jpg"
src={session.user.image || "/images/temp-avatars/av1.jpg"}
alt="Profile Picture"
width="40"
height="40"
width="36"
height="36"
className="rounded-full"
></Image>
<p data-cy="username" className="hidden lg:flex">
{session.user.name || session.user.email}
</p>
</div>
</Box>
</Popover.Button>
<AnimatePresence initial={false}>
{open && (
@@ -54,35 +66,45 @@ export function UserMenu() {
y: -10,
transition: { duration: 0.2 },
}}
className="absolute right-0 mt-3 w-screen bg-inherit max-w-xs p-4 rounded-md border border-slate-300/70"
>
<Box className="flex flex-col gap-1">
{accountOptions.map((item) => (
<a
key={item.name}
href={item.href}
aria-label={item.desc}
className="flex items-center rounded-md hover:bg-gray-200/50"
<Box
bg={backgroundColor}
borderWidth="1px"
borderColor={accentColor}
borderRadius="xl"
className="absolute right-0 mt-3 w-screen max-w-xs p-4"
>
<Box className="flex flex-col gap-1">
{accountOptions.map((item) => (
<Link
key={item.name}
href={item.href}
aria-label={item.desc}
className="flex items-center"
bg={backgroundColor}
_hover={{ textDecoration: "none" }}
>
<div className="p-4">
<item.icon className="text-blue-500" aria-hidden="true" />
</div>
<div>
<Text fontFamily="inter">{item.name}</Text>
</div>
</Link>
))}
<Link
className="flex items-center rounded-md cursor-pointer"
_hover={{ textDecoration: "none" }}
onClick={() => signOut({ callbackUrl: "/" })}
>
<div className="p-4">
<item.icon aria-hidden="true" />
<FiLogOut className="text-blue-500" />
</div>
<div>
<p>{item.name}</p>
<Text fontFamily="inter">Sign Out</Text>
</div>
</a>
))}
<a
className="flex items-center rounded-md hover:bg-gray-100 cursor-pointer"
onClick={() => signOut({ callbackUrl: "/" })}
>
<div className="p-4">
<FaSignOutAlt />
</div>
<div>
<p>Sign Out</p>
</div>
</a>
</Link>
</Box>
</Box>
</Popover.Panel>
</Box>
+7
View File
@@ -25,4 +25,11 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => (
</div>
);
export const getDashboardLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
{page}
</div>
);
export const noLayout = (page: React.ReactElement) => page;
@@ -0,0 +1,12 @@
import { Box, CircularProgress, Stack, StackDivider, useColorModeValue } from "@chakra-ui/react";
import { MessageTableEntry } from "./MessageTableEntry";
export function MessageTable({ messages }) {
return (
<Stack divider={<StackDivider />} spacing="4">
{messages.map((item, idx) => (
<MessageTableEntry item={item} idx={idx} key={item.id} />
))}
</Stack>
);
}
@@ -0,0 +1,24 @@
import { Avatar, Box, HStack, LinkBox, useColorModeValue } from "@chakra-ui/react";
import { boolean } from "boolean";
import NextLink from "next/link";
import { FlaggableElement } from "../FlaggableElement";
export function MessageTableEntry({ item, idx }) {
const bgColor = useColorModeValue(idx % 2 === 0 ? "bg-slate-800" : "bg-black", "bg-sky-900");
return (
<FlaggableElement text={item.text} post_id={item.id} key={`flag_${item.id}`}>
<HStack>
<Avatar
name={`${boolean(item.is_assistant) ? "Assitant" : "User"}`}
src={`${boolean(item.is_assistant) ? "/images/logos/logo.png" : "/images/temp-avatars/av1.jpg"}`}
/>
<LinkBox className={`p-4 rounded-md text-white whitespace-pre-wrap ${bgColor} text-white w-full`}>
<NextLink href={`/messages/${item.id}`} passHref>
{item.text}
</NextLink>
</LinkBox>
</HStack>
</FlaggableElement>
);
}
+78
View File
@@ -0,0 +1,78 @@
import { Container } from "./Container";
const Roadmap = () => {
return (
<Container className="">
<div className="py-32">
<h2 className="text-4xl mb-16">Our Roadmap</h2>
<div className="flex flex-col items-center space-y-8 md:space-y-0 md:items-start md:flex-row md:justify-between">
<div className="flex flex-col items-center space-y-4">
<div className="h-[5rem] w-[5rem] border-4 border-[#a72a1e] rounded-full flex items-center justify-center">
<p className="font-bold text-[#a72a1e] text-center">ASAP</p>
</div>
<h4 className="font-bold text-xl text-[#a72a1e] text-center max-w-[10rem]">Minimum Viable Prototype</h4>
<ul className="ml-6 md:ml-8 lg:ml-6 space-y-4 text-[#a72a1e] list-disc">
<li>Data Collection Pipeline</li>
<li>RL on Human Feedback</li>
<li>Assistant v1 usable</li>
<li>Out January 2023!</li>
</ul>
</div>
<div>
<span className="w-[4vw] h-[4px] mt-8 bg-[#a72a1e] rounded-full hidden md:block" />
</div>
<span className="w-[4px] h-16 bg-[#a72a1e] rounded-full block md:hidden" />
<div className="flex flex-col items-center space-y-4">
<div className="h-[5rem] w-[5rem] border-4 border-[#858585] rounded-full flex items-center justify-center">
<p className="font-bold text-[#858585] text-center">
Q1
<br />
2023
</p>
</div>
<h4 className="font-bold text-xl text-[#858585] text-center max-w-[10rem]">Growing Up</h4>
<ul className="ml-6 md:ml-8 lg:ml-6 space-y-4 text-[#858585] list-disc">
<li>Retrieval Augmentation</li>
<li>Rapid Personalization</li>
<li>Using External Tools</li>
</ul>
</div>
<div>
<span className="w-[4vw] h-[4px] mt-8 bg-[#858585] rounded-full hidden md:block" />
</div>
<span className="w-[4px] h-16 bg-[#858585] rounded-full block md:hidden" />
<div className="flex flex-col items-center space-y-4">
<div className="h-[5rem] w-[5rem] border-4 border-[#858585] rounded-full flex items-center justify-center">
<p className="font-bold text-[#858585] text-center">
Q2
<br />
2023
</p>
</div>
<h4 className="font-bold text-xl text-[#858585] text-center max-w-[10rem]">Growing Up</h4>
<ul className="ml-6 md:ml-8 lg:ml-6 space-y-4 text-[#858585] list-disc">
<li>Third-Party Extentions</li>
<li>Device Control</li>
<li>Multi-Modality</li>
</ul>
</div>
<div>
<span className="w-[4vw] h-[4px] mt-8 bg-[#858585] rounded-full hidden md:block" />
</div>
<span className="w-[4px] h-16 bg-[#858585] rounded-full block md:hidden" />
<div className="flex flex-col items-center space-y-4">
<div className="h-[5rem] w-[5rem] border-4 border-[#858585] rounded-full flex items-center justify-center">
<p className="font-bold text-[#858585] text-center">...</p>
</div>
<h4 className="font-bold text-xl text-[#858585] text-center max-w-[10rem]">Growing Up</h4>
<ul className="ml-6 md:ml-8 lg:ml-6 space-y-4 text-[#858585] list-disc">
<li>What do you need?</li>
</ul>
</div>
</div>
</div>
</Container>
);
};
export default Roadmap;
+50
View File
@@ -0,0 +1,50 @@
import { Container } from "./Container";
const Services = () => {
return (
<div className="bg-white py-32 border-t-[1px] border-gray-150">
<Container className="">
<div className="grid grid-cols-1 lg:grid-cols-10 lg:grid-rows-2">
<div className="grid col-span-3 row-span-2">
<div className="min-h-[17.5rem] px-12 sm:px-16 row-span-1 flex flex-col justify-center bg-[#307bf3] rounded-tl-[45px] lg:rounded-tl-none lg:rounded-bl-[45px] rounded-br-[45px] lg:rounded-br-none lg:rounded-tr-[45px]">
<h4 className="font-bold text-white text-xl mb-4">Your Conversational Assistant</h4>
<span className="w-8 h-[2px] bg-white mb-4 block lg:hidden" />
<p className="text-white">State-of-the-Art chat assistant that can be personalized to your needs</p>
</div>
<div className="min-h-[17.5rem] px-12 sm:px-16 row-span-1 flex flex-col justify-center bg-[#275ddf] rounded-tl-[45px] rounded-br-[45px]">
<h4 className="font-bold text-white text-xl mb-4">Interface w/ external systems</h4>
<span className="w-8 h-[2px] bg-white mb-4 block lg:hidden" />
<p className="text-white">
Usage of APIs and third-party applications, described via language & demonstrations.
</p>
</div>
</div>
<div className="grid grid-rows-2 col-span-3 row-span-2">
<div className="min-h-[17.5rem] px-12 sm:px-16 row-span-1 flex flex-col justify-center bg-[#307bf3] lg:bg-[#275ddf] rounded-tl-[45px] rounded-br-[45px]">
<h4 className="font-bold text-white text-xl mb-4">Retrieval via Search Engines</h4>
<span className="w-8 h-[2px] bg-white mb-4 block lg:hidden" />
<p className="text-white">External, upgradeable knowledge: No need for billions of parameters.</p>
</div>
<div className="min-h-[17.5rem] px-12 sm:px-16 row-span-1 flex flex-col justify-center bg-[#275ddf] lg:bg-[#307bf3] rounded-tl-[45px] lg:rounded-tl-none lg:rounded-bl-[45px] rounded-br-[45px] lg:rounded-br-none lg:rounded-tr-[45px]">
<h4 className="font-bold text-white text-xl mb-4">A building block for developers</h4>
<span className="w-8 h-[2px] bg-white mb-4 block lg:hidden" />
<p className="text-white">Integrate OpenAssistant into your application.</p>
</div>
</div>
<div className="px-12 sm:px-16 py-20 lg:p-20 col-span-4 row-span-2 bg-[#1a44a1] lg:flex lg:flex-col lg:justify-center rounded-tl-[45px] rounded-br-[45px] lg:rounded-tr-[80px] lg:rounded-tl-none lg:rounded-br-none">
<h4 className="font-bold text-white text-xl mb-4">OpenAssistant unifies all knowledge work in one place</h4>
<span className="w-8 h-[2px] bg-white mb-4 block lg:hidden" />
<ul className="ml-4 sm:ml-12 mt-8 space-y-4 list-disc text-white">
<li>Uses modern deep learning</li>
<li>Runs on consumer hardware</li>
<li>Trains on human feedback</li>
<li>Free and open</li>
</ul>
</div>
</div>
</Container>
</div>
);
};
export default Services;
+2 -1
View File
@@ -18,6 +18,7 @@ import {
} from "@dnd-kit/sortable";
import { ReactNode, useEffect, useState } from "react";
import { CollapsableText } from "../CollapsableText";
import { SortableItem } from "./SortableItem";
export interface SortableProps {
@@ -64,7 +65,7 @@ export const Sortable = (props: SortableProps) => {
<Flex direction="column" gap={2} className={extraClasses}>
{itemsWithIds.map(({ id, item }) => (
<SortableItem key={id} id={id}>
{item}
<CollapsableText text={item} />
</SortableItem>
))}
</Flex>
@@ -4,7 +4,7 @@ import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
interface TaskControlsProps {
export interface TaskControlsProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
@@ -30,11 +30,11 @@ export const TaskControls = (props: TaskControlsProps) => {
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
{endTask.task.type !== "task_done" ? (
<SubmitButton data-cy="submit" onClick={() => props.onSubmitResponse(props.tasks[0])}>
<SubmitButton colorScheme="blue" data-cy="submit" onClick={() => props.onSubmitResponse(props.tasks[0])}>
Submit
</SubmitButton>
) : (
<SubmitButton data-cy="next-task" onClick={props.onSkip}>
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onSkip}>
Next Task
</SubmitButton>
)}
@@ -0,0 +1,60 @@
import {
Button,
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
useDisclosure,
} from "@chakra-ui/react";
import { TaskControls, TaskControlsProps } from "./TaskControls";
interface TaskControlsOverridableProps extends TaskControlsProps {
isValid: boolean;
prepareForSubmit: () => void;
}
export const TaskControlsOverridable = (props: TaskControlsOverridableProps) => {
const { isValid, onSubmitResponse, ...rest } = props;
const { isOpen: isModalOpen, onOpen: onOpenModal, onClose: onModalClose } = useDisclosure();
const unchangedResponsePrompt = () => {
onOpenModal();
// Ideally this happens when the user clicks submit, but we can't
// reliably wait for it to be executed before submitting the response
// without significant refactoring.
// As a result, modal will only display once even if the user doesn't proceed
props.prepareForSubmit();
};
const onSubmitResponseOverride = () => {
onSubmitResponse(props.tasks[0]);
onModalClose();
};
return (
<>
<Modal isOpen={isModalOpen} onClose={onModalClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalCloseButton />
<ModalHeader>Order Unchanged</ModalHeader>
<ModalBody>You have not changed the order of the prompts. Are you sure you would like to submit?</ModalBody>
<ModalFooter>
<Flex justify="center" ml="auto" gap={2}>
<Button variant={"ghost"} onClick={onModalClose}>
Cancel
</Button>
<Button onClick={onSubmitResponseOverride}>Submit anyway</Button>
</Flex>
</ModalFooter>
</ModalContent>
</Modal>
<TaskControls onSubmitResponse={isValid ? props.onSubmitResponse : unchangedResponsePrompt} {...rest} />
</>
);
};
+25
View File
@@ -0,0 +1,25 @@
import { Container } from "./Container";
import Image from "next/image";
const Vision = () => {
return (
<div className="bg-gray-900 py-20">
<Container className="">
<div className="grid gap-16 items-center py-20 md:py-32 lg:grid-cols-2">
<div>
<h2 className="text-4xl text-white mb-6">Our Vision</h2>
<p className="text-2xl text-gray-400">
We want OpenAssistant to be the single, unifying platform that all other systems use to interface with
humans.
</p>
</div>
<div className="m-auto rounded-tl-[45px] rounded-br-[45px] overflow-hidden">
<Image src="/images/temp-avatars/av2.jpg" width={450} height={450} alt="temp-image" />
</div>
</div>
</Container>
</div>
);
};
export default Vision;
+98
View File
@@ -0,0 +1,98 @@
import { JWT } from "next-auth/jwt";
declare global {
// eslint-disable-next-line no-var
var oasstApiClient: OasstApiClient | undefined;
}
class OasstError {
message: string;
errorCode: number;
httpStatusCode: number;
constructor(message: string, errorCode: number, httpStatusCode: number) {
this.message = message;
this.errorCode = errorCode;
this.httpStatusCode = httpStatusCode;
}
}
export class OasstApiClient {
constructor(private readonly oasstApiUrl: string, private readonly oasstApiKey: string) {}
private async post(path: string, body: any): Promise<any> {
const resp = await fetch(`${this.oasstApiUrl}${path}`, {
method: "POST",
headers: {
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
},
body: JSON.stringify(body),
});
if (resp.status == 204) {
return null;
}
if (resp.status >= 300) {
const errorText = await resp.text();
try {
const error = JSON.parse(errorText);
throw new OasstError(error.message, error.error_code, resp.status);
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
}
return await resp.json();
}
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
// This is a raw Json type, so we can't use it to strongly type the task.
async fetchTask(taskType: string, userToken: JWT): Promise<any> {
return this.post("/api/v1/tasks/", {
type: taskType,
user: {
id: userToken.sub,
display_name: userToken.name || userToken.email,
auth_method: "local",
},
});
}
async ackTask(taskId: string, messageId: string): Promise<void> {
return this.post(`/api/v1/tasks/${taskId}/ack`, {
message_id: messageId,
});
}
// TODO return a strongly typed Task?
// This method is used to record interaction with task while fetching next task.
// This is a raw Json type, so we can't use it to strongly type the task.
async interactTask(
updateType: string,
messageId: string,
userMessageId: string,
content: object,
userToken: JWT
): Promise<any> {
return this.post("/api/v1/tasks/interaction", {
type: updateType,
user: {
id: userToken.sub,
display_name: userToken.name || userToken.email,
auth_method: "local",
},
message_id: messageId,
user_message_id: userMessageId,
...content,
});
}
}
export const oasstApiClient =
globalThis.oasstApiClient || new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);
if (process.env.NODE_ENV !== "production") {
globalThis.oasstApiClient = oasstApiClient;
}
+1 -1
View File
@@ -4,5 +4,5 @@ export { default } from "next-auth/middleware";
* Guards all pages under `/grading` and redirects them to the sign in page.
*/
export const config = {
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*"],
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
};
+39
View File
@@ -0,0 +1,39 @@
import { Container } from "src/components/Container";
import Services from "src/components/Services";
import Vision from "src/components/Vision";
import Roadmap from "src/components/Roadmap";
import { CallToAction } from "src/components/CallToAction";
import Image from "next/image";
const AboutPage = () => {
return (
<div>
<Container className="">
<div className="grid gap-16 items-center py-20 md:py-40 lg:grid-cols-2">
<div className="m-auto order-2 lg:order-1">
<Image src="/images/logos/logo.png" width={450} height={450} alt="temp-image" />
</div>
<div className="space-y-8 order-1 lg:order-2">
<div>
<h1 className="text-4xl mb-6">What is OpenAssistant?</h1>
<p className="text-2xl">
OpenAssistant is a chat-based assistant that understands tasks, can interact with third-party systems,
and retrieve information dynamically to do so.
</p>
</div>
<p className="text-2xl">
It can be extended and personalized easily and is developed as free, open-source software.
</p>
</div>
</div>
</Container>
<Services />
<Vision />
<Roadmap />
<CallToAction />
</div>
);
};
export default AboutPage;
+20 -17
View File
@@ -7,6 +7,7 @@ import React, { useState } from "react";
export default function Account() {
const { data: session } = useSession();
const [username, setUsername] = useState("");
const updateUser = async (e: React.SyntheticEvent) => {
e.preventDefault();
try {
@@ -16,6 +17,7 @@ export default function Account() {
headers: { "Content-Type": "application/json" },
body: JSON.stringify(body),
});
session.user.name = username;
await Router.push("/account");
} catch (error) {
console.error(error);
@@ -34,23 +36,24 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<form onSubmit={updateUser}>
<InputGroup>
<Input
onChange={(e) => setUsername(e.target.value)}
placeholder="Edit Username"
type="text"
value={username}
></Input>
<Button disabled={!username} type="submit" value="Change">
Submit
</Button>
</InputGroup>
</form>
<p>{session.user.email}</p>
</main>
<div className="oa-basic-theme">
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<form onSubmit={updateUser}>
<InputGroup>
<Input
onChange={(e) => setUsername(e.target.value)}
placeholder="Edit Username"
type="text"
value={username}
></Input>
<Button disabled={!username} type="submit" value="Change">
Submit
</Button>
</InputGroup>
</form>
</main>
</div>
</>
);
}
+9 -7
View File
@@ -19,13 +19,15 @@ export default function Account() {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<main className="h-3/4 z-0 bg-white flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<Button>
<Link href="/account/edit">Edit Username</Link>
</Button>
<p>{session.user.email}</p>
</main>
<div className="oa-basic-theme">
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
<p>{session.user.name || "No username"}</p>
<Button>
<Link href="/account/edit">Edit Username</Link>
</Button>
<p>{session.user.email}</p>
</main>
</div>
</>
);
}
@@ -59,6 +59,17 @@ if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development")
);
}
// Create a map of provider types to a set of admin user identifiers based on
// the environment variables. We assume the list is separated by ',' and each
// entry is separated by ':'.
const adminUserMap = process.env.ADMIN_USERS.split(",").reduce((result, entry) => {
const [authType, id] = entry.split(":");
const s = result.get(authType) || new Set();
s.add(id);
result.set(authType, s);
return result;
}, new Map());
export const authOptions: AuthOptions = {
// Ensure we can store user data in a database.
adapter: PrismaAdapter(prisma),
@@ -68,6 +79,56 @@ export const authOptions: AuthOptions = {
verifyRequest: "/auth/verify",
// error: "/auth/error", -Will be used later
},
callbacks: {
/**
* Ensure we propagate the user's role when creating the session from the
* token.
*/
async session({ session, token }) {
session.user.role = token.role;
return session;
},
/**
* When creating a token, fetch the user's role and inject it in the token.
* This let's use forward the role to the session object.
*/
async jwt({ token }) {
const { role } = await prisma.user.findUnique({
where: { id: token.sub },
select: { role: true },
});
token.role = role;
return token;
},
},
events: {
/**
* Update the user's role after they have successfully signed in
*/
async signIn({ user, account }) {
// Get the admin list for the user's auth type.
const adminForAccountType = adminUserMap.get(account.provider);
// Return early if there's no admin list.
if (!adminForAccountType) {
return;
}
// TODO(#236): Reduce the number of times we update the role field.
// Update the database if the user is an admin.
if (adminForAccountType.has(account.providerAccountId)) {
await prisma.user.update({
data: {
role: "admin",
},
where: {
id: user.id,
},
});
}
},
},
session: {
strategy: "jwt",
},
+24
View File
@@ -0,0 +1,24 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
res.status(200).json(messages);
};
export default handler;
+29
View File
@@ -0,0 +1,29 @@
import { getToken } from "next-auth/jwt";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
//TODO: add params if needed
const params = new URLSearchParams({
username: token.sub,
});
const messagesRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/messages?${params}`, {
method: "GET",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
},
});
const messages = await messagesRes.json();
// Send recieved messages to the client.
res.status(200).json(messages);
};
export default handler;
+3 -28
View File
@@ -1,4 +1,5 @@
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
/**
@@ -21,24 +22,7 @@ const handler = async (req, res) => {
}
// Fetch the new task.
//
// This needs to be refactored into an easier to use library.
const taskRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
type: task_type,
user: {
id: token.sub,
display_name: token.name || token.email,
auth_method: "local",
},
}),
});
const task = await taskRes.json();
const task = await oasstApiClient.fetchTask(task_type, token);
// Store the task and link it to the user..
const registeredTask = await prisma.registeredTask.create({
@@ -53,16 +37,7 @@ const handler = async (req, res) => {
});
// Update the backend with our Task ID
await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/${task.id}/ack`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
message_id: registeredTask.id,
}),
});
await oasstApiClient.ackTask(task.id, registeredTask.id);
// Send the results to the client.
res.status(200).json(registeredTask);
+2 -21
View File
@@ -1,4 +1,5 @@
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
/**
@@ -34,27 +35,7 @@ const handler = async (req, res) => {
},
});
// Send the interaction to the Task Backend. This automatically fetches the
// next task in the sequence (or the done task).
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/tasks/interaction`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
type: update_type,
user: {
id: token.sub,
display_name: token.name || token.email,
auth_method: "local",
},
message_id: id,
user_message_id: interaction.id,
...content,
}),
});
const newTask = await interactionRes.json();
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
// Stores the new task with our database.
const newRegisteredTask = await prisma.registeredTask.create({
+1
View File
@@ -1,4 +1,5 @@
import { getSession } from "next-auth/react";
import prisma from "../../lib/prismadb";
// POST /api/post
// Required fields in body: title
+4 -3
View File
@@ -15,13 +15,13 @@ function Signin({ csrfToken, providers }) {
const emailEl = useRef(null);
const signinWithEmail = (ev: React.FormEvent) => {
ev.preventDefault();
signIn(email.id, { callbackUrl: "/", email: emailEl.current.value });
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
};
const debugUsernameEl = useRef(null);
function signinWithDebugCredentials(ev: React.FormEvent) {
ev.preventDefault();
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
signIn(credentials.id, { callbackUrl: "/dashboard", username: debugUsernameEl.current.value });
}
const { colorMode } = useColorMode();
@@ -98,7 +98,8 @@ function Signin({ csrfToken, providers }) {
</Button>
)}
</Stack>
<div className="pt-10 text-center">
<hr className="mt-14 mb-4 h-px bg-gray-200 border-0" />
<div className="text-center">
By signing up you agree to our <br></br>
<Link href="/terms-of-service" aria-label="Terms of Service" className="hover:underline underline-offset-4">
<b>Terms of Service</b>
+7 -1
View File
@@ -1,6 +1,6 @@
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
@@ -21,6 +21,12 @@ const AssistantReply = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
+7 -1
View File
@@ -1,6 +1,6 @@
import { Container, Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
@@ -27,6 +27,12 @@ const InitialPrompt = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const submitResponse = (task: { id: string }) => {
const text = inputRef.current.value.trim();
trigger({
+7 -1
View File
@@ -1,6 +1,6 @@
import { Textarea } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
@@ -21,6 +21,12 @@ const UserReply = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
+33
View File
@@ -0,0 +1,33 @@
import { Box, useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { getDashboardLayout } from "src/components/Layout";
import { LeaderboardTable, SideMenu, TaskOption } from "src/components/Dashboard";
import { colors } from "styles/Theme/colors";
const Dashboard = () => {
const { colorMode } = useColorMode();
return (
<>
<Head>
<title>Dashboard - Open Assistant</title>
<meta name="description" content="Chat with Open Assistant and provide feedback." />
</Head>
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
<Box className="sm:flex h-full gap-6">
<Box className="p-6 sm:pr-0">
<SideMenu />
</Box>
<Box className="flex flex-col overflow-auto p-6 sm:pl-0 gap-14">
<TaskOption />
<LeaderboardTable />
</Box>
</Box>
</Box>
</>
);
};
Dashboard.getLayout = (page) => getDashboardLayout(page);
export default Dashboard;
@@ -1,10 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -24,6 +26,12 @@ const RankAssistantReplies = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
@@ -64,6 +72,7 @@ const RankAssistantReplies = () => {
}
const replies = tasks[0].task.replies as string[];
const messages = tasks[0].task.conversation.messages as Message[];
return (
<>
@@ -77,10 +86,17 @@ const RankAssistantReplies = () => {
<p className="text-lg py-1">
Given the following replies, sort them from best to worst, best being first, worst being last.
</p>
<ContextMessages messages={messages} />
<Sortable items={replies} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task.replies.length}
prepareForSubmit={() => setRanking(tasks[0].task.replies.map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
</>
);
@@ -1,10 +1,10 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -32,6 +32,12 @@ const RankInitialPrompts = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const submitResponse = (task) => {
trigger({
id: task.id,
@@ -79,7 +85,13 @@ const RankInitialPrompts = () => {
<Sortable items={tasks[0].task.prompts} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task.prompts.length}
prepareForSubmit={() => setRanking(tasks[0].task.prompts.map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
</>
);
@@ -1,10 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useState } from "react";
import { useEffect, useState } from "react";
import { ContextMessages } from "src/components/ContextMessages";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { Sortable } from "src/components/Sortable/Sortable";
import { SurveyCard } from "src/components/Survey/SurveyCard";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverridable";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
@@ -24,6 +26,12 @@ const RankUserReplies = () => {
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
@@ -63,6 +71,7 @@ const RankUserReplies = () => {
);
}
const replies = tasks[0].task.replies as string[];
const messages = tasks[0].task.conversation.messages as Message[];
return (
<>
@@ -76,10 +85,17 @@ const RankUserReplies = () => {
<p className="text-lg py-1">
Given the following replies, sort them from best to worst, best being first, worst being last.
</p>
<ContextMessages messages={messages} />
<Sortable items={replies} onChange={setRanking} className="my-8" />
</SurveyCard>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
<TaskControlsOverridable
tasks={tasks}
isValid={ranking.length == tasks[0].task.replies.length}
prepareForSubmit={() => setRanking(tasks[0].task.replies.map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
/>
</div>
</>
);
+5 -15
View File
@@ -1,16 +1,10 @@
import Head from "next/head";
import { useSession } from "next-auth/react";
import { CallToAction } from "src/components/CallToAction";
import { Faq } from "src/components/Faq";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
import { Hero } from "src/components/Hero";
import { getTransparentHeaderLayout } from "src/components/Layout";
import { TaskSelection } from "src/components/TaskSelection";
const Home = () => {
const { data: session } = useSession();
return (
<>
<Head>
@@ -20,15 +14,11 @@ const Home = () => {
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
{session ? (
<TaskSelection />
) : (
<main className="oa-basic-theme">
<Hero />
<CallToAction />
<Faq />
</main>
)}
<main className="oa-basic-theme">
<Hero />
<CallToAction />
<Faq />
</main>
</>
);
};

Some files were not shown because too many files have changed in this diff Show More