Lewis Tunstall
2023-01-04 15:21:23 +11:00
240 changed files with 35119 additions and 9731 deletions
+93
@@ -0,0 +1,93 @@
# devcontainer
## example usage
Below are some example use cases you might want to run from within the
devcontainer (either
[within VSCode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
or in your browser via
[GitHub Codespaces](https://github.com/features/codespaces)).
### Run pre-commit
```bash
# run pre-commit
pre-commit run --all-files
```
A successful run should look something like this:
```
@andrewm4894 ➜ /workspaces/Open-Assistant (devcontainer-improvements) $ pre-commit run --all-files
[INFO] Initializing environment for https://github.com/pre-commit/pre-commit-hooks.
[INFO] Initializing environment for https://github.com/psf/black.
[INFO] Initializing environment for https://github.com/psf/black:.[jupyter].
[INFO] Initializing environment for https://github.com/pycqa/flake8.
[INFO] Initializing environment for https://github.com/pycqa/isort.
[INFO] Initializing environment for https://github.com/pre-commit/mirrors-prettier.
[INFO] Initializing environment for https://github.com/pre-commit/mirrors-prettier:prettier@2.7.1.
[INFO] Initializing environment for local.
[INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
[INFO] Installing environment for https://github.com/psf/black.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
[INFO] Installing environment for https://github.com/pycqa/flake8.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
[INFO] Installing environment for https://github.com/pycqa/isort.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
[INFO] Installing environment for https://github.com/pre-commit/mirrors-prettier.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
[INFO] Installing environment for local.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check yaml...............................................................Passed
check json...............................................................Passed
check for case conflicts.................................................Passed
detect private key.......................................................Passed
fix python encoding pragma...............................................Passed
forbid submodules....................................(no files to check)Skipped
mixed line ending........................................................Passed
fix requirements.txt.....................................................Passed
check that executables have shebangs.....................................Passed
check that scripts with shebangs are executable..........................Passed
check BOM - deprecated: use fix-byte-order-marker........................Passed
check for broken symlinks............................(no files to check)Skipped
check for merge conflicts................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
black-jupyter............................................................Passed
flake8...................................................................Passed
isort....................................................................Passed
prettier.................................................................Passed
Lint website.............................................................Passed
```
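If you want to check a run like the one above from a script, the sketch below
parses the summary lines. It assumes the plain-text layout shown in the sample
output (hook name, dot padding, status) with no ANSI colour codes; that layout
is not a documented interface, and the names `HOOK_LINE`, `summarize_hooks` and
`failed_hooks` are made up for this example. The exit code of `pre-commit`
remains the reliable signal.

```python
import re

# One summary line per hook: name, dot padding, optional "(no files to check)", status.
# This pattern is inferred from the sample output above, not from a documented format.
HOOK_LINE = re.compile(
    r"^(?P<name>.+?)\.{2,}(?:\([^)]*\))?(?P<status>Passed|Failed|Skipped)$"
)


def summarize_hooks(output):
    """Map each hook name found in the summary lines to its status."""
    results = {}
    for line in output.splitlines():
        match = HOOK_LINE.match(line.strip())
        if match:  # [INFO] lines and blank lines do not match and are ignored
            results[match.group("name")] = match.group("status")
    return results


def failed_hooks(output):
    """Return the names of hooks whose summary line ends in "Failed"."""
    return [name for name, status in summarize_hooks(output).items() if status == "Failed"]
```

Feeding it the captured text of `pre-commit run --all-files` and checking that
`failed_hooks(...)` is empty gives a quick pass/fail summary.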
### Docker compose
```bash
# build the images and start the services
docker compose up --build
```
You should see Docker images being pulled and containers starting up.
Once you see a line like:
```
open-assistant-web-1 | Listening on port 3000 url: http://localhost:3000
```
you should be able to access that port like below:
<img width="640" alt="image" src="https://user-images.githubusercontent.com/2178292/210395676-e9c2aab5-cb54-4ae6-b1eb-ac929fd73607.png">
This port can then be forwarded to a browser tab like below:
<img width="640" alt="image" src="https://user-images.githubusercontent.com/2178292/210396207-1b2e259f-4d5d-475d-b225-91e2bd004071.png">
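If you would rather script the wait than watch the logs, the sketch below polls
the URL from the log line above until something answers. It is only an
illustration: `wait_for_url` and its parameters are hypothetical names, not part
of this repository, and any HTTP response (even an error status) is treated as
"the server is up".

```python
import time
import urllib.error
import urllib.request


def wait_for_url(url, timeout=120.0, interval=2.0,
                 opener=urllib.request.urlopen, sleep=time.sleep):
    """Poll `url` until it answers or `timeout` seconds pass; return True on success."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with opener(url, timeout=5):
                return True  # got a successful HTTP response
        except urllib.error.HTTPError:
            return True  # server answered, just with an error status
        except (urllib.error.URLError, OSError):
            pass  # nothing is listening yet
        if time.monotonic() >= deadline:
            return False
        sleep(interval)
```

For example, `wait_for_url("http://localhost:3000")` returns `True` once the
website container accepts requests, or `False` after two minutes.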
+15
@@ -0,0 +1,15 @@
{
  "name": "Open-Assistant",
  "image": "mcr.microsoft.com/vscode/devcontainers/universal",
  "features": {
    "ghcr.io/devcontainers-contrib/features/pre-commit:2": {
      "version": "latest"
    }
  },
  "postCreateCommand": "bash .devcontainer/post_create_command.sh",
  "customizations": {
    "vscode": {
      "extensions": ["GitHub.copilot"]
    }
  }
}
+2
@@ -0,0 +1,2 @@
# ensure the pre-commit git hooks are installed
pre-commit install
+1
@@ -0,0 +1 @@
**/node_modules
+2 -2
@@ -12,7 +12,7 @@ on:
workflow_call:
jobs:
build-frontend:
build-frontend:
runs-on: ubuntu-latest
defaults:
run:
@@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-node@v3
with:
node-version: 16.x
cache: 'npm'
cache: "npm"
cache-dependency-path: website/package-lock.json
- run: npm ci
- run: npm run build
+30
@@ -0,0 +1,30 @@
name: Test API Contract

on:
  push:
    branches:
      - main
  pull_request:
  workflow_call:

jobs:
  test-contract:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: "3.10"
      - run: cd oasst-shared && pip install -e .
      - run: cd oasst-shared && pip install -r requirements.dev.txt
      - run: cd backend && pip install -r requirements.txt
      - run: ./scripts/backend-development/start-mock-server.sh
      - name: Run contract tests
        run: ./scripts/oasst-shared-development/test.sh
      - run: ./scripts/backend-development/stop-mock-server.sh
+3
@@ -5,3 +5,6 @@
*.egg-info
__pycache__
.DS_Store
# Generated files
backend/oasst-openapi.json
+36 -9
@@ -1,7 +1,32 @@
exclude: "build|stubs|^bot/templates/|openassistant/templates"
# WARNING!
#
# When making changes to auto-formatters used in pre-commit hooks, you are
# likely to cause merge conflicts with main and/or other pull requests.
# Fixing them might revert other people's work. Expect pain!
# To avoid accidental reversions and keep it easy to review, please make sure
# that changes here are in a pull request by themselves, that it consists of
# two commits:
#
# 1. The changes to this file
# 2. Changes made by running `python3 -m pre_commit run --all-files`.
#
# Then each time your pull request is blocked by a merge conflict, do the
# following steps:
#
# git reset HEAD^1 && git checkout -f # discard the change commit
# git rebase main # re-apply other people's changes
# python3 -m pre_commit run --all-files # re-run the rules
# git add . # add the newly changed files
# git commit -m 'apply pre-commit' # commit it
# git push -f # force push back to your branch
#
# Keep in mind you may have to do this a few times, as changes here may impact
# other pull requests. Try to keep it up-to-date so they can go in when it'll
# cause least disruption.
#
# /WARNING!
default_language_version:
python: python3
exclude: "build|stubs|^bot/templates/|openassistant/templates/$"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -14,11 +39,12 @@ repos:
# and which break the standard YAML check. The alternative would be to
# skip any unsafe errors (and thus break YAML compatibility) or use
# some other checker that may not work in general.
exclude: "^copilot/web/addons/.*$"
exclude: ^copilot/.*/addons/.*$
- id: check-json
- id: check-case-conflict
- id: detect-private-key
- id: fix-encoding-pragma
args: [--remove]
- id: forbid-submodules
- id: mixed-line-ending
- id: requirements-txt-fixer
@@ -28,13 +54,13 @@ repos:
- id: check-symlinks
- id: check-merge-conflict
- id: check-added-large-files
args: ["--maxkb=1024"]
args: [--maxkb=1024]
- id: end-of-file-fixer
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black
- id: black-jupyter
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
@@ -50,14 +76,15 @@ repos:
rev: v2.7.1
hooks:
- id: prettier
args: ["--write"]
args: [--prose-wrap=always, --write]
- repo: local
hooks:
- id: next-lint-website
name: Lint website
files: ^website/
exclude: ^website/node_modules/
types_or: [javascript, jsx, ts, tsx]
language: system
language: node
pass_filenames: false
entry: bash -c 'cd website && npm ci && npm run lint'
entry: website/next-lint.js
+2
@@ -1,2 +1,4 @@
* @yk @andreaskoepf
/website/ @fozziethebeat @k-nearest-neighbor @AbdBarho
/model/ @theblackcat102 @sanagno
/copilot/ @fozziethebeat @andreaskoepf @yk
+102 -35
@@ -1,51 +1,95 @@
# Open-Assistant
Open Assistant is a project meant to give everyone access to a great chat based large language model.
Open Assistant is a project meant to give everyone access to a great chat based
large language model.
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope Open Assistant can help improve the world by improving language itself.
We believe that by doing this we will create a revolution in innovation in
language. In the same way that stable-diffusion helped the world make art and
images in new ways we hope Open Assistant can help improve the world by
improving language itself.
## Do you want to try it out?
If you are interested in taking a look at the current state of the project, You can set up an entire stack needed to run **Open-Assistant**, including the
If you are interested in taking a look at the current state of the project, you
can set up an entire stack needed to run **Open-Assistant**, including the
website, backend, and associated dependent services.
To start the demo, Run this in the root directory of the repository:
To start the demo, run this in the root directory of the repository:
```sh
docker compose up --build
```
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and interact with the website.
Then, navigate to `http://localhost:3000` (it may take some time to boot up) and
interact with the website.
**Note:** When logging in via email, navigate to `http://localhost:1080` to get the magic email login link.
**Note:** When logging in via email, navigate to `http://localhost:1080` to get
the magic email login link.
**Note:** If you would like to run this in a standardized development
environment (a
["devcontainer"](https://code.visualstudio.com/docs/devcontainers/containers))
using
[vscode locally](https://code.visualstudio.com/docs/devcontainers/create-dev-container#_create-a-devcontainerjson-file)
or in a web browser using
[GitHub Codespaces](https://github.com/features/codespaces), you can use the
provided [`.devcontainer`](.devcontainer/) folder.
## The Plan
We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
We want to get to an initial MVP as fast as possible, by following the 3-steps
outlined in the InstructGPT paper.
1. Collect high-quality human generated Instruction-Fulfillment samples (prompt + response), goal >50k. We design a crowdsourced process to collect and reviewed prompts. We do not want to train on flooding/toxic/spam/junk/personal information data. We will have a leaderboard to motivate the community that shows progress and the most active users. Swag will be given to the top-contributors.
2. For each of the collected prompts we will sample multiple completions. Completions of one prompt will then be shown randomly to users to rank them from best to worst. Again this should happen crowd-sourced, e.g. we need to deal with unreliable potentially malicious users. At least multiple votes by independent users have to be collected to measure the overall agreement. The gathered ranking-data will be used to train a reward model.
3. Now follows the RLHF training phase based on the prompts and the reward model.
1. Collect high-quality human generated Instruction-Fulfillment samples
(prompt + response), goal >50k. We design a crowdsourced process to collect
   and review prompts. We do not want to train on
flooding/toxic/spam/junk/personal information data. We will have a
leaderboard to motivate the community that shows progress and the most active
users. Swag will be given to the top-contributors.
2. For each of the collected prompts we will sample multiple completions.
Completions of one prompt will then be shown randomly to users to rank them
from best to worst. Again this should happen crowd-sourced, e.g. we need to
deal with unreliable potentially malicious users. At least multiple votes by
independent users have to be collected to measure the overall agreement. The
gathered ranking-data will be used to train a reward model.
3. Now follows the RLHF training phase based on the prompts and the reward
model.
We can then take the resulting model and continue with completion sampling step 2 for a next iteration.
We can then take the resulting model and continue with completion sampling step
2 for a next iteration.
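
To make the agreement idea in step 2 concrete, here is a minimal sketch. It
assumes every user ranks the same set of completions from best to worst. The
Borda-style mean position and the pairwise agreement score are stand-ins chosen
for illustration, not the method the project has settled on, and the function
names are hypothetical.

```python
from collections import defaultdict
from itertools import combinations


def borda_order(rankings):
    """Combine several best-to-worst rankings into one order by summed position."""
    totals = defaultdict(int)
    for ranking in rankings:
        for position, item in enumerate(ranking):
            totals[item] += position
    # A lower total means the item was ranked better overall; ties break by name.
    return sorted(totals, key=lambda item: (totals[item], item))


def pairwise_agreement(rankings):
    """Fraction of (user pair, completion pair) comparisons where both users agree."""
    agree = total = 0
    for first, second in combinations(rankings, 2):
        pos_first = {item: i for i, item in enumerate(first)}
        pos_second = {item: i for i, item in enumerate(second)}
        for x, y in combinations(sorted(pos_first), 2):
            total += 1
            if (pos_first[x] < pos_first[y]) == (pos_second[x] < pos_second[y]):
                agree += 1
    return agree / total if total else 1.0
```

A low agreement score for a prompt could be a signal to collect more votes
before its rankings are used to train the reward model.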
## The Vision
We are not going to stop at replicating ChatGPT. We want to build the assistant of the future, able to not only write email and cover letters, but do meaningful work, use APIs, dynamically research information, and much more, with the ability to be personalized and extended by anyone. And we want to do this in a way that is open and accessible, which means we must not only build a great assistant, but also make it small and efficient enough to run on consumer hardware.
We are not going to stop at replicating ChatGPT. We want to build the assistant
of the future, able to not only write email and cover letters, but do meaningful
work, use APIs, dynamically research information, and much more, with the
ability to be personalized and extended by anyone. And we want to do this in a
way that is open and accessible, which means we must not only build a great
assistant, but also make it small and efficient enough to run on consumer
hardware.
### Slide Decks
[Vision & Roadmap](https://docs.google.com/presentation/d/1n7IrAOVOqwdYgiYrXc8Sj0He8krn5MVZO_iLkCjTtu0/edit?usp=sharing)
[Important Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
## How can you help?
All open source projects begins with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity.
All open source projects begin with people like you. Open source is the belief
that if we collaborate we can together gift our knowledge and technology to the
world for the benefit of humanity.
## I'm in! Now what?
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
this is for work coordination.
[and / or the YK Discord Server](https://ykilcher.com/discord)
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
a dedicated channel and is more public.
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
dedicated, but not as active, channel.
[Visit the Notion](https://ykilcher.com/open-assistant)
@@ -53,30 +97,41 @@ All open source projects begins with people like you. Open source is the belief
We have a growing task list
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
that appeals to you and make a comment that you'd like to work on it. Include
in your comment a brief description of how you'll solve the problem and if
there are any open questions you want to discuss. Once a project coordinator
has assigned the issue to you, start working on it.
that appeals to you and make a comment that you'd like to work on it. Include in
your comment a brief description of how you'll solve the problem and if there
are any open questions you want to discuss. Once a project coordinator has
assigned the issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in
Discord and someone can help clarify the issue with more detail.
If the issue is currently unclear but you are interested, please post in Discord
and someone can help clarify the issue with more detail.
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
the system architecture, and other documentation.
### Submitting Work
We're all working on different parts of Open Assistant together. To make
contributions smoothly we recommend the following:
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
and clone it to your local machine. (Read more
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
1. Before working on any changes, try to
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
to keep it up-to-date with the upstream repository.
1. Work on a small focused change that only touches on a few files.
1. Run `pre-commit` and make sure all files have formatting fixed. This
simplifies life for reviewers.
1. Package up a small bit of work that solves part of the problem into a Pull
Request and send it out for review
1. Package up a small bit of work that solves part of the problem
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
and
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
1. If you're lucky, we can merge your change into `main` without any problems.
   If there are changes to files you're working on, resolve them by:
1. First try rebase as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase)
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
1. If rebase feels too painful, merge as suggested
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge)
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
1. Once you've resolved any conflicts, finish the review and merge into `main`.
1. Merge in your change and move onto a new issue or the second step of your
current issue.
@@ -95,20 +150,27 @@ addressed now, or filing an issue to handle it later.
## Developer Setup
Work is organized in the [project board](https://github.com/orgs/LAION-AI/projects/3).
Work is organized in the
[project board](https://github.com/orgs/LAION-AI/projects/3).
**Anything that is in the `Todo` column and not assigned, is up for grabs. Meaning we'd be happy for anyone to do these tasks.**
**Anything in the `Todo` column that is not assigned is up for grabs, meaning
we'd be happy for anyone to do these tasks.**
If you want to work on something, assign yourself to it or write a comment that you want to work on it and what you plan to do.
If you want to work on something, assign yourself to it or write a comment that
you want to work on it and what you plan to do.
- To get started with development, if you want to work on the backend, have a look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at `scripts/frontend-development/README.md` to make a backend available.
- To get started with development, if you want to work on the backend, have a
look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at
`scripts/frontend-development/README.md` to make a backend available.
There is also a minimal implementation of a frontend in the `text-frontend` folder.
There is also a minimal implementation of a frontend in the `text-frontend`
folder.
We are using Python 3.10 for the backend.
Check out the [High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
Check out the
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
### Website
@@ -116,10 +178,15 @@ The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit hooks.
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
### Deployment
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
Upon making a release on GitHub, all docker images are automatically built and
pushed to ghcr.io. The docker images are tagged with the release version, and
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
automatically deploy the built release to the dev machine.
+5 -2
@@ -2,7 +2,9 @@
## REST Server Configuration
Please either use environment variables or create a `.env` file in the backend root directory (in which this readme file is located) to specify the `DATABASE_URI`.
Please either use environment variables or create a `.env` file in the backend
root directory (in which this readme file is located) to specify the
`DATABASE_URI`.
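
As a rough illustration of the precedence described above (environment variable
first, `.env` file as fallback), consider the sketch below. It is not how the
backend itself loads its settings; `load_env_file` and `get_setting` are
hypothetical helpers, and the parser only handles simple `KEY=VALUE` lines.

```python
import os
from pathlib import Path


def load_env_file(path):
    """Parse simple KEY=VALUE lines; blank lines and # comments are skipped."""
    values = {}
    for raw in Path(path).read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")  # split on the first "=" only
        value = value.strip()
        # Drop one pair of matching surrounding quotes, if present.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key.strip()] = value
    return values


def get_setting(name, env_file=".env", environ=os.environ):
    """Return the environment variable if set, else the value from the .env file."""
    if name in environ:
        return environ[name]
    path = Path(env_file)
    if path.is_file():
        return load_env_file(path).get(name)
    return None
```

With a `.env` file in the backend root, `get_setting("DATABASE_URI")` returns
the value from the file unless the variable is already set in the environment.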
Example contents of a `.env` file for the backend:
@@ -14,4 +16,5 @@ BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://local
## Running the REST Server locally for development
Have a look into the main `README.md` file for more information on how to set up the backend for development.
Have a look into the main `README.md` file for more information on how to set up
the backend for development.
-1
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from logging.config import fileConfig
import sqlmodel
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""first revision
Revision ID: 23e5fea252dd
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""v1 db structure
Revision ID: cd7de470586e
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""add auth_method to person
Revision ID: 6368515778c5
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""add_auth_method_to_ix_person_username
Revision ID: 0daec5f8135f
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""Adds text labels table.
Revision ID: 067c4002f2d9
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""add_journal_table
Revision ID: 3358eb6834e6
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""post ref for work_package
Revision ID: d24b37426857
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""Added lang column for ISO-639-1 codes
Revision ID: ef0b52902560
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""add collective flag to task
Revision ID: 464ec4667aae
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""add field trusted api client
Revision ID: 73ce3675c1f5
@@ -0,0 +1,338 @@
"""name changes: person->user, post->message, work_package->task
Revision ID: abb47e9d145a
Revises: 73ce3675c1f5
Create Date: 2022-12-30 20:54:49.880568
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "abb47e9d145a"
down_revision = "73ce3675c1f5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# clear DB
op.execute("DELETE FROM journal;")
op.execute("DELETE FROM work_package;")
op.execute("DELETE FROM post_reaction;")
op.execute("DELETE FROM post;")
op.execute("DELETE FROM person_stats;")
op.execute("DELETE FROM person;")
op.execute("DELETE FROM text_labels;")
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"user",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("auth_method", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_user_username", "user", ["api_client_id", "username", "auth_method"], unique=True)
op.create_table(
"message",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("depth", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("children_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("parent_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("role", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column("lang", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_message_frontend_message_id", "message", ["api_client_id", "frontend_message_id"], unique=True)
op.create_index(op.f("ix_message_message_tree_id"), "message", ["message_tree_id"], unique=False)
op.create_index(op.f("ix_message_task_id"), "message", ["task_id"], unique=False)
op.create_index(op.f("ix_message_user_id"), "message", ["user_id"], unique=False)
op.create_table(
"task",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("expiry_date", sa.DateTime(), nullable=True),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("done", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("ack", sa.Boolean(), nullable=True),
sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("parent_message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_task_user_id"), "task", ["user_id"], unique=False)
op.create_table(
"user_stats",
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("leader_score", sa.Integer(), nullable=False),
sa.Column("reactions", sa.Integer(), nullable=False),
sa.Column("messages", sa.Integer(), nullable=False),
sa.Column("upvotes", sa.Integer(), nullable=False),
sa.Column("downvotes", sa.Integer(), nullable=False),
sa.Column("task_reward", sa.Integer(), nullable=False),
sa.Column("compare_wins", sa.Integer(), nullable=False),
sa.Column("compare_losses", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("user_id"),
)
op.create_table(
"message_reaction",
sa.Column("task_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["task_id"],
["task.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("task_id", "user_id"),
)
op.drop_constraint("text_labels_post_id_fkey", "text_labels", type_="foreignkey")
op.drop_constraint("journal_post_id_fkey", "journal", type_="foreignkey")
op.drop_constraint("journal_person_id_fkey", "journal", type_="foreignkey")
op.drop_table("post_reaction")
op.drop_index("ix_post_frontend_post_id", table_name="post")
op.drop_index("ix_post_person_id", table_name="post")
op.drop_index("ix_post_thread_id", table_name="post")
op.drop_index("ix_post_workpackage_id", table_name="post")
op.drop_table("post")
op.drop_index("ix_work_package_person_id", table_name="work_package")
op.drop_table("work_package")
op.drop_table("person_stats")
op.drop_index("ix_person_username", table_name="person")
op.drop_table("person")
op.add_column("journal", sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
op.add_column("journal", sa.Column("message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
op.drop_index("ix_journal_person_id", table_name="journal")
op.create_index(op.f("ix_journal_user_id"), "journal", ["user_id"], unique=False)
op.create_foreign_key(None, "journal", "user", ["user_id"], ["id"])
op.create_foreign_key(None, "journal", "message", ["message_id"], ["id"])
op.drop_column("journal", "person_id")
op.drop_column("journal", "post_id")
op.add_column("text_labels", sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True))
op.create_foreign_key(None, "text_labels", "message", ["message_id"], ["id"])
op.drop_column("text_labels", "post_id")
# ### end Alembic commands ###
def downgrade() -> None:
# clear DB
op.execute("DELETE FROM journal;")
op.execute("DELETE FROM message_reaction;")
op.execute("DELETE FROM task;")
op.execute("DELETE FROM message;")
op.execute("DELETE FROM user_stats;")
op.execute('DELETE FROM "user";')
op.execute("DELETE FROM text_labels;")
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("text_labels", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
op.drop_constraint("text_labels_message_id_fkey", "text_labels", type_="foreignkey")
op.drop_column("text_labels", "message_id")
op.add_column("journal", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
op.add_column("journal", sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True))
op.drop_constraint("journal_message_id_fkey", "journal", type_="foreignkey")
op.drop_constraint("journal_user_id_fkey", "journal", type_="foreignkey")
op.drop_index(op.f("ix_journal_user_id"), table_name="journal")
op.create_index("ix_journal_person_id", "journal", ["person_id"], unique=False)
op.drop_column("journal", "message_id")
op.drop_column("journal", "user_id")
op.create_table(
"person",
sa.Column(
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
),
sa.Column("username", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
sa.Column("display_name", sa.VARCHAR(length=256), autoincrement=False, nullable=False),
sa.Column(
"created_date",
postgresql.TIMESTAMP(),
server_default=sa.text("CURRENT_TIMESTAMP"),
autoincrement=False,
nullable=False,
),
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column("auth_method", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="person_api_client_id_fkey"),
sa.PrimaryKeyConstraint("id", name="person_pkey"),
)
op.create_table(
"person_stats",
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column("leader_score", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column(
"modified_date",
postgresql.TIMESTAMP(),
server_default=sa.text("CURRENT_TIMESTAMP"),
autoincrement=False,
nullable=False,
),
sa.Column("reactions", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("posts", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("upvotes", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("downvotes", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("work_reward", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("compare_wins", sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column("compare_losses", sa.INTEGER(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="person_stats_person_id_fkey"),
sa.PrimaryKeyConstraint("person_id", name="person_stats_pkey"),
)
op.create_table(
"work_package",
sa.Column(
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
),
sa.Column(
"created_date",
postgresql.TIMESTAMP(),
server_default=sa.text("CURRENT_TIMESTAMP"),
autoincrement=False,
nullable=False,
),
sa.Column("expiry_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column("done", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
sa.Column("ack", sa.BOOLEAN(), autoincrement=False, nullable=True),
sa.Column("frontend_ref_post_id", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=True),
sa.Column("parent_post_id", postgresql.UUID(), autoincrement=False, nullable=True),
sa.Column("collective", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="work_package_api_client_id_fkey"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="work_package_person_id_fkey"),
sa.PrimaryKeyConstraint("id", name="work_package_pkey"),
)
op.create_index("ix_work_package_person_id", "work_package", ["person_id"], unique=False)
op.create_table(
"post",
sa.Column(
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
),
sa.Column("parent_id", postgresql.UUID(), autoincrement=False, nullable=True),
sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column("workpackage_id", postgresql.UUID(), autoincrement=False, nullable=True),
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column("role", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
sa.Column("frontend_post_id", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
sa.Column(
"created_date",
postgresql.TIMESTAMP(),
server_default=sa.text("CURRENT_TIMESTAMP"),
autoincrement=False,
nullable=False,
),
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
sa.Column("depth", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("children_count", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("lang", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_api_client_id_fkey"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_person_id_fkey"),
sa.PrimaryKeyConstraint("id", name="post_pkey"),
)
op.create_index("ix_post_workpackage_id", "post", ["workpackage_id"], unique=False)
op.create_index("ix_post_thread_id", "post", ["thread_id"], unique=False)
op.create_index("ix_post_person_id", "post", ["person_id"], unique=False)
op.create_index("ix_post_frontend_post_id", "post", ["api_client_id", "frontend_post_id"], unique=False)
op.create_table(
"post_reaction",
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column(
"created_date",
postgresql.TIMESTAMP(),
server_default=sa.text("CURRENT_TIMESTAMP"),
autoincrement=False,
nullable=False,
),
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.Column("work_package_id", postgresql.UUID(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_reaction_api_client_id_fkey"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_reaction_person_id_fkey"),
sa.ForeignKeyConstraint(["work_package_id"], ["work_package.id"], name="post_reaction_work_package_id_fkey"),
)
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=False)
op.create_foreign_key("text_labels_post_id_fkey", "text_labels", "post", ["post_id"], ["id"])
op.create_foreign_key("journal_person_id_fkey", "journal", "person", ["person_id"], ["id"])
op.create_foreign_key("journal_post_id_fkey", "journal", "post", ["post_id"], ["id"])
op.drop_table("message_reaction")
op.drop_table("user_stats")
op.drop_index(op.f("ix_task_user_id"), table_name="task")
op.drop_table("task")
op.drop_index(op.f("ix_message_user_id"), table_name="message")
op.drop_index(op.f("ix_message_task_id"), table_name="message")
op.drop_index(op.f("ix_message_message_tree_id"), table_name="message")
op.drop_index("ix_message_frontend_message_id", table_name="message")
op.drop_table("message")
op.drop_index("ix_user_username", table_name="user")
op.drop_table("user")
# ### end Alembic commands ###
@@ -0,0 +1,27 @@
"""add deleted field to post
Revision ID: 8d269bc4fdbd
Revises: abb47e9d145a
Create Date: 2022-12-31 04:38:41.799206
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "8d269bc4fdbd"
down_revision = "abb47e9d145a"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "deleted")
# ### end Alembic commands ###
+124 -63
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from math import ceil
from pathlib import Path
from typing import Optional
@@ -7,13 +7,15 @@ import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.middleware.cors import CORSMiddleware
@@ -24,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
return fastapi.responses.JSONResponse(
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
status_code=int(ex.http_status_code),
content=protocol_schema.OasstErrorResponse(
message=ex.message,
error_code=OasstErrorCode(ex.error_code),
).dict(),
)
@@ -63,14 +70,37 @@ if settings.UPDATE_ALEMBIC:
logger.exception("Alembic upgrade failed on startup")
if settings.RATE_LIMIT:
@app.on_event("startup")
async def connect_redis():
async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
"""Error callback function when too many requests"""
expire = ceil(pexpire / 1000)
raise OasstError(
f"Too Many Requests. Retry After {expire} seconds.",
OasstErrorCode.TOO_MANY_REQUESTS,
HTTPStatus.TOO_MANY_REQUESTS,
)
try:
redis_client = redis.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
)
logger.info(f"Connected to {redis_client=}")
await FastAPILimiter.init(redis_client, http_callback=http_callback)
except Exception:
logger.exception("Failed to establish Redis connection")
if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
def seed_data():
class DummyPost(pydantic.BaseModel):
task_post_id: str
user_post_id: str
parent_post_id: Optional[str]
class DummyMessage(pydantic.BaseModel):
task_message_id: str
user_message_id: str
parent_message_id: Optional[str]
text: str
role: str
@@ -81,96 +111,97 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
dummy_posts = [
DummyPost(
task_post_id="de111fa8",
user_post_id="6f1d0711",
parent_post_id=None,
dummy_messages = [
DummyMessage(
task_message_id="de111fa8",
user_message_id="6f1d0711",
parent_message_id=None,
text="Hi!",
role="user",
role="prompter",
),
DummyPost(
task_post_id="74c381d4",
user_post_id="4a24530b",
parent_post_id="6f1d0711",
DummyMessage(
task_message_id="74c381d4",
user_message_id="4a24530b",
parent_message_id="6f1d0711",
text="Hello! How can I help you?",
role="assistant",
),
DummyPost(
task_post_id="3d5dc440",
user_post_id="a8c01c04",
parent_post_id="4a24530b",
DummyMessage(
task_message_id="3d5dc440",
user_message_id="a8c01c04",
parent_message_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="user",
role="prompter",
),
DummyPost(
task_post_id="643716c1",
user_post_id="f43a93b7",
parent_post_id="4a24530b",
DummyMessage(
task_message_id="643716c1",
user_message_id="f43a93b7",
parent_message_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="user",
role="prompter",
),
DummyPost(
task_post_id="2e4e1e6",
user_post_id="c886920",
parent_post_id="6f1d0711",
DummyMessage(
task_message_id="2e4e1e6",
user_message_id="c886920",
parent_message_id="6f1d0711",
text="Hey buddy! How can I serve you?",
role="assistant",
),
DummyPost(
task_post_id="970c437d",
user_post_id="cec432cf",
parent_post_id=None,
DummyMessage(
task_message_id="970c437d",
user_message_id="cec432cf",
parent_message_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="user",
role="prompter",
),
DummyPost(
task_post_id="6066118e",
user_post_id="4f85f637",
parent_post_id="cec432cf",
DummyMessage(
task_message_id="6066118e",
user_message_id="4f85f637",
parent_message_id="cec432cf",
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
role="assistant",
),
DummyPost(
task_post_id="ba87780d",
user_post_id="0e276b98",
parent_post_id="cec432cf",
DummyMessage(
task_message_id="ba87780d",
user_message_id="0e276b98",
parent_message_id="cec432cf",
text="I'm unsure how to interpret this. Is it a riddle?",
role="assistant",
),
]
for p in dummy_posts:
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
if wp and not wp.ack:
logger.warning("Deleting unacknowledged seed data work package")
db.delete(wp)
wp = None
if not wp:
if p.parent_post_id is None:
wp = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
print("p.parent_post_id", p.parent_post_id)
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
wp = pr.store_task(
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
)
),
thread_id=parent_post.thread_id,
parent_post_id=parent_post.id,
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_post_id(wp.id, p.task_post_id)
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
pr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
logger.info(
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
)
else:
logger.debug(f"seed data work_package found: {wp.id}")
logger.debug(f"seed data task found: {task.id}")
logger.info("Seed data check completed")
except Exception:
@@ -178,3 +209,33 @@ if settings.DEBUG_USE_SEED_DATA:
app.include_router(api_router, prefix=settings.API_V1_STR)
def get_openapi_schema():
return json.dumps(app.openapi())
if __name__ == "__main__":
# Importing here so we don't import packages unnecessarily if we're
# importing main as a module.
import argparse
import json
import uvicorn
parser = argparse.ArgumentParser()
parser.add_argument(
"--print-openapi-schema",
help="Dumps the openapi schema to stdout",
action=argparse.BooleanOptionalAction,
)
parser.add_argument("--host", help="The host to run the server")
parser.add_argument("--port", help="The port to run the server")
args = parser.parse_args()
if args.print_openapi_schema:
print(get_openapi_schema())
else:
uvicorn.run(app, host=args.host, port=args.port)
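The command-line handling above can be tried in isolation. The following is a minimal sketch, not the project's code: `build_parser` is a hypothetical helper, and the `--host` and `--port` defaults are assumptions chosen to match uvicorn's own defaults. It shows that `argparse.BooleanOptionalAction` (Python 3.9+) generates a paired `--no-print-openapi-schema` flag and leaves the value as `None` when neither flag is given.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: the original builds the parser inline under __main__.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--print-openapi-schema",
        help="Dumps the openapi schema to stdout",
        action=argparse.BooleanOptionalAction,
    )
    # Assumed defaults (uvicorn's defaults); the original sets none.
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--no-print-openapi-schema", "--port", "8080"])
    # The flag is explicitly False here, and the port is parsed as an int.
    print(args.print_openapi_schema, args.host, args.port)
```

Because the unset value is `None`, a plain truthiness check such as `if args.print_openapi_schema:` treats "flag omitted" and `--no-print-openapi-schema` the same way.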
+79 -3
@@ -1,16 +1,16 @@
# -*- coding: utf-8 -*-
from http import HTTPStatus
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import Security
from fastapi import Depends, Request, Response, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from fastapi_limiter.depends import RateLimiter
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlmodel import Session
@@ -64,3 +64,79 @@ def api_auth(
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
def get_api_client(
api_key: APIKey = Depends(get_api_key),
db: Session = Depends(get_db),
):
return api_auth(api_key, db)
def get_trusted_api_client(
api_key: APIKey = Depends(get_api_key),
db: Session = Depends(get_db),
):
client = api_auth(api_key, db)
if not client.trusted:
raise OasstError(
"Forbidden",
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
return client
class UserRateLimiter(RateLimiter):
def __init__(
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
async def identifier(request: Request) -> str:
"""Identify a request based on api_key and user.id"""
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
user = (await request.json()).get("user")
return f"{api_key}:{user.get('id')}"
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
# Skip if rate limiting is disabled
if not settings.RATE_LIMIT:
return
# Attempt to retrieve api_key and user information
user = (await request.json()).get("user")
# Skip when api_key and user information are not available
# (such that it will be handled by `APIClientRateLimiter`)
if not api_key or not user or not user.get("id"):
return
return await super().__call__(request, response)
class APIClientRateLimiter(RateLimiter):
def __init__(
self, times: int = 10_000, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
async def identifier(request: Request) -> str:
"""Identify a request based on api_key and user.id"""
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
return f"{api_key}"
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
# Skip if rate limiting is disabled
if not settings.RATE_LIMIT:
return
# Attempt to retrieve api_key and user information
user = (await request.json()).get("user")
# Skip if user information is available
# (such that it will be handled by `UserRateLimiter`)
if not api_key or user:
return
return await super().__call__(request, response)
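The two limiters above split requests between them based on whether the request body carries a user. The sketch below restates that routing as a plain function so it can be checked without FastAPI or Redis. `rate_limit_key` is a hypothetical name used only for illustration; the skip conditions and key formats are copied from the `__call__` and `identifier` bodies above.

```python
from typing import Optional, Tuple


def rate_limit_key(api_key: Optional[str], user: Optional[dict]) -> Optional[Tuple[str, str]]:
    """Return (limiter name, Redis identifier), or None when both limiters skip."""
    if not api_key:
        # Both limiters return early without an API key.
        return None
    if user and user.get("id"):
        # UserRateLimiter: keyed on api_key and user id.
        return ("user", f"{api_key}:{user['id']}")
    if not user:
        # APIClientRateLimiter: keyed on api_key only.
        return ("api_client", f"{api_key}")
    # A user object without an id is skipped by both limiters.
    return None


if __name__ == "__main__":
    print(rate_limit_key("key-1", {"id": "u1"}))
    print(rate_limit_key("key-1", None))
    print(rate_limit_key("key-1", {"display_name": "no id"}))
```

One consequence worth noting: a request whose `user` object is present but has no `id` is not counted by either limiter.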
+16 -2
View File
@@ -1,7 +1,21 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks, text_labels
from oasst_backend.api.v1 import (
frontend_messages,
frontend_users,
leaderboards,
messages,
stats,
tasks,
text_labels,
users,
)
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
api_router.include_router(frontend_messages.router, prefix="/frontend_messages", tags=["frontend_messages"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
@@ -0,0 +1,111 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@router.get("/{message_id}", response_model=protocol.Message)
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
if not isinstance(message.payload.payload, MessagePayload):
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return utils.prepare_message(message)
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the conversation from the tree root up to the message with the given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all children of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
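The conversation and children endpoints above rely on `PromptRepository` methods whose implementation is not part of this file. As a rough illustration of what they return, the sketch below walks parent links in memory. `PARENTS`, `conversation_to` and `children_of` are hypothetical names, and the parent links are taken from the seed data in `main.py`; the real repository queries the database instead.

```python
from typing import Dict, List, Optional

# Frontend message ID -> parent frontend message ID (subset of the seed data).
PARENTS: Dict[str, Optional[str]] = {
    "6f1d0711": None,
    "4a24530b": "6f1d0711",
    "a8c01c04": "4a24530b",
    "f43a93b7": "4a24530b",
    "c886920": "6f1d0711",
}


def conversation_to(message_id: str, parents: Dict[str, Optional[str]]) -> List[str]:
    """Return the path from the tree root down to message_id."""
    path: List[str] = []
    current: Optional[str] = message_id
    while current is not None:
        path.append(current)
        current = parents[current]
    path.reverse()  # collected leaf-to-root, returned root-to-leaf
    return path


def children_of(message_id: str, parents: Dict[str, Optional[str]]) -> List[str]:
    """Return the direct replies to message_id (not the whole subtree)."""
    return [mid for mid, parent in parents.items() if parent == message_id]


if __name__ == "__main__":
    print(conversation_to("a8c01c04", PARENTS))
    print(children_of("4a24530b", PARENTS))
```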
@@ -0,0 +1,52 @@
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
username: str,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
include_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if include_deleted else False,
)
return utils.prepare_message_list(messages)
@router.delete("/{username}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
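The query endpoint above passes `deleted=None if include_deleted else False`, which reads as a three-state filter: `None` means "do not filter on the flag", while `True` or `False` select only matching rows. The body of `PromptRepository.query_messages` is not shown here, so the sketch below is an assumption about that behaviour using plain dictionaries; `deleted_filter` and `filter_by_deleted` are hypothetical names.

```python
from typing import List, Optional


def deleted_filter(include_deleted: bool) -> Optional[bool]:
    """Map the endpoint flag to the three-state repository argument."""
    return None if include_deleted else False


def filter_by_deleted(messages: List[dict], deleted: Optional[bool]) -> List[dict]:
    """Assumed semantics: None keeps everything, a bool keeps matching rows."""
    if deleted is None:
        return list(messages)
    return [m for m in messages if m["deleted"] is deleted]


if __name__ == "__main__":
    rows = [{"id": 1, "deleted": False}, {"id": 2, "deleted": True}]
    print(filter_by_deleted(rows, deleted_filter(include_deleted=False)))
    print(filter_by_deleted(rows, deleted_filter(include_deleted=True)))
```

Under this reading, soft-deleted messages stay in the table (the `deleted` column added by the migration) and are hidden by default rather than removed.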
@@ -0,0 +1,25 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from sqlmodel import Session
router = APIRouter()
@router.get("/create/assistant")
def get_assistant_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="prompter")
+147
@@ -0,0 +1,147 @@
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/", response_model=list[protocol.Message])
def query_messages(
username: str = None,
api_client_id: str = None,
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
allow_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if allow_deleted else False,
)
return utils.prepare_message_list(messages)
@router.get("/{message_id}", response_model=protocol.Message)
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
if not isinstance(message.payload.payload, MessagePayload):
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return utils.prepare_message(message)
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the conversation from the tree root up to the message with the given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get all children of the message.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_children(message_id)
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr.mark_messages_deleted(message_id)
+17
@@ -0,0 +1,17 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@router.get("/", response_model=protocol.SystemStats)
def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_stats()
+67 -55
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import random
from typing import Any, Optional, Tuple
from uuid import UUID
@@ -7,10 +6,11 @@ from fastapi import APIRouter, Depends
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@@ -18,8 +18,8 @@ router = APIRouter()
def generate_task(
request: protocol_schema.TaskRequest, pr: PromptRepository
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
thread_id = None
parent_post_id = None
message_tree_id = None
parent_message_id = None
match request.type:
case protocol_schema.TaskRequestType.random:
@@ -54,38 +54,42 @@ def generate_task(
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
posts = pr.fetch_random_conversation("assistant")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
case protocol_schema.TaskRequestType.prompter_reply:
logger.info("Generating a PrompterReplyTask.")
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
for msg in messages
]
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating an AssistantReplyTask.")
posts = pr.fetch_random_conversation("user")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
messages = pr.fetch_random_conversation("prompter")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
for msg in messages
]
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
posts = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
case protocol_schema.TaskRequestType.rank_user_replies:
logger.info("Generating a RankUserRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankPrompterRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
messages = [
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
@@ -93,18 +97,18 @@ def generate_task(
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankUserRepliesTask(
task = protocol_schema.RankPrompterRepliesTask(
conversation=protocol_schema.Conversation(
messages=messages,
messages=task_messages,
),
replies=replies,
)
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
messages = [
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
@@ -113,7 +117,7 @@ def generate_task(
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(messages=messages),
conversation=protocol_schema.Conversation(messages=task_messages),
replies=replies,
)
case _:
@@ -121,10 +125,17 @@ def generate_task(
logger.info(f"Generated {task=}.")
return task, thread_id, parent_post_id
return task, message_tree_id, parent_message_id
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
@router.post(
"/",
response_model=protocol_schema.AnyTask,
dependencies=[
Depends(deps.UserRateLimiter(times=100, minutes=5)),
Depends(deps.APIClientRateLimiter(times=10_000, minutes=1)),
],
) # work with Union once more types are added
def request_task(
*,
db: Session = Depends(deps.get_db),
@@ -138,8 +149,8 @@ def request_task(
try:
pr = PromptRepository(db, api_client, request.user)
task, thread_id, parent_post_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_post_id, request.collective)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
@@ -149,14 +160,14 @@ def request_task(
return task
@router.post("/{task_id}/ack")
def acknowledge_task(
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
ack_request: protocol_schema.TaskAck,
) -> Any:
) -> None:
"""
The frontend acknowledges a task.
"""
@@ -166,26 +177,25 @@ def acknowledge_task(
try:
pr = PromptRepository(db, api_client, user=None)
# here we store the post id in the database for the task
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
raise
except Exception:
logger.exception("Failed to acknowledge task.")
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
return {}
@router.post("/{task_id}/nack")
def acknowledge_task_failure(
@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
nack_request: protocol_schema.TaskNAck,
) -> Any:
) -> None:
"""
The frontend reports failure to implement a task.
"""
@@ -200,8 +210,8 @@ def acknowledge_task_failure(
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@router.post("/interaction")
def post_interaction(
@router.post("/interaction", response_model=protocol_schema.TaskDone)
def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -216,29 +226,31 @@ def post_interaction(
pr = PromptRepository(db, api_client, user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToPost:
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
pr.store_text_reply(
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
return protocol_schema.TaskDone()
case protocol_schema.PostRating:
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
)
# here we store the rating in the database
pr.store_rating(interaction)
return protocol_schema.TaskDone()
case protocol_schema.PostRanking:
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
)
# TODO: check if the ranking is valid
@@ -254,7 +266,7 @@ def post_interaction(
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
@router.post("/close")
@router.post("/close", response_model=protocol_schema.TaskDone)
def close_collective_task(
close_task_request: protocol_schema.TaskClose,
db: Session = Depends(deps.get_db),
@@ -262,5 +274,5 @@ def close_collective_task(
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.post_id)
pr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
+2 -3
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import pydantic
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
@@ -7,7 +6,7 @@ from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -17,7 +16,7 @@ class LabelTextRequest(pydantic.BaseModel):
user: protocol_schema.User
@router.post("/")
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
*,
db: Session = Depends(deps.get_db),
+53
View File
@@ -0,0 +1,53 @@
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
def query_user_messages(
user_id: UUID,
api_client_id: UUID = None,
max_count: int = Query(10, gt=0, le=1000),
start_date: datetime.datetime = None,
end_date: datetime.datetime = None,
only_roots: bool = False,
desc: bool = True,
include_deleted: bool = False,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query user messages.
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.query_messages(
user_id=user_id,
api_client_id=api_client_id,
desc=desc,
limit=max_count,
start_date=start_date,
end_date=end_date,
only_roots=only_roots,
deleted=None if include_deleted else False,
)
return utils.prepare_message_list(messages)
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
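`query_user_messages` passes `deleted=None if include_deleted else False`, which suggests a three-state filter: `None` for "do not filter", `False` for "only live messages", `True` for "only deleted ones". `pr.query_messages` is not shown in this hunk, so the sketch below is an assumption about that contract, using a hypothetical `filter_deleted` helper over plain dicts.

```python
from typing import Optional


def filter_deleted(rows: list[dict], deleted: Optional[bool]) -> list[dict]:
    """Apply a three-state `deleted` filter (assumed semantics, see lead-in)."""
    if deleted is None:
        # No filter: return deleted and live rows alike.
        return list(rows)
    return [row for row in rows if row["deleted"] is deleted]


rows = [{"id": 1, "deleted": False}, {"id": 2, "deleted": True}]
include_deleted = False
visible = filter_deleted(rows, None if include_deleted else False)
print([row["id"] for row in visible])
```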
+45
View File
@@ -0,0 +1,45 @@
from http import HTTPStatus
from uuid import UUID
from oasst_backend.models import Message
from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
def prepare_message(m: Message) -> protocol.Message:
if not isinstance(m.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
return protocol.Message(
id=m.id,
parent_id=m.parent_id,
text=m.payload.payload.text,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
)
def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
return [prepare_message(m) for m in messages]
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv_messages.append(
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
)
return protocol.Conversation(messages=conv_messages)
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
tree_messages = []
for message in tree:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(prepare_message(message))
return protocol.MessageTree(id=tree_id, messages=tree_messages)
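The helpers above all follow one pattern: check the payload type, then map each database row to a protocol object, deriving `is_assistant` from `role`. The following sketch reproduces that mapping with hypothetical dataclasses (`StubMessage`, `StubConversationMessage`) in place of the SQLModel and protocol classes, so that it runs without the backend installed.

```python
from dataclasses import dataclass
from typing import Optional
from uuid import UUID, uuid4


@dataclass
class StubMessage:
    """Hypothetical stand-in for oasst_backend.models.Message."""

    id: UUID
    parent_id: Optional[UUID]
    text: str
    role: str  # "prompter" or "assistant"


@dataclass
class StubConversationMessage:
    """Hypothetical stand-in for protocol.ConversationMessage."""

    text: str
    is_assistant: bool


def to_conversation(messages: list[StubMessage]) -> list[StubConversationMessage]:
    # Same mapping as prepare_conversation: copy the text, derive the flag.
    return [
        StubConversationMessage(text=m.text, is_assistant=(m.role == "assistant"))
        for m in messages
    ]


root = StubMessage(id=uuid4(), parent_id=None, text="Hi", role="prompter")
reply = StubMessage(id=uuid4(), parent_id=root.id, text="Hello!", role="assistant")
print(to_conversation([root, reply]))
```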
+4 -1
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
@@ -15,6 +14,10 @@ class Settings(BaseSettings):
POSTGRES_DB: str = "postgres"
DATABASE_URI: Optional[PostgresDsn] = None
RATE_LIMIT: bool = True
REDIS_HOST: str = "localhost"
REDIS_PORT: str = "6379"
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
-1
View File
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
__all__ = []
-1
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
+1 -2
View File
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
from oasst_backend.config import settings
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlmodel import create_engine
if settings.DATABASE_URI is None:
+38 -39
View File
@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
import enum
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
from oasst_backend.models import ApiClient, Journal, Task, User
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_shared.utils import utcnow
from pydantic import BaseModel
@@ -14,71 +13,71 @@ class JournalEventType(str, enum.Enum):
"""The type of an event recorded in the journal."""
user_created = "user_created"
text_reply_to_post = "text_reply_to_post"
post_rating = "post_rating"
post_ranking = "post_ranking"
text_reply_to_message = "text_reply_to_message"
message_rating = "message_rating"
message_ranking = "message_ranking"
@payload_type
class JournalEvent(BaseModel):
type: str
person_id: Optional[UUID]
post_id: Optional[UUID]
workpackage_id: Optional[UUID]
user_id: Optional[UUID]
message_id: Optional[UUID]
task_id: Optional[UUID]
task_type: Optional[str]
@payload_type
class TextReplyEvent(JournalEvent):
type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
type: Literal[JournalEventType.text_reply_to_message] = JournalEventType.text_reply_to_message
length: int
role: str
@payload_type
class RatingEvent(JournalEvent):
type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
type: Literal[JournalEventType.message_rating] = JournalEventType.message_rating
rating: int
@payload_type
class RankingEvent(JournalEvent):
type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
type: Literal[JournalEventType.message_ranking] = JournalEventType.message_ranking
ranking: list[int]
class JournalWriter:
def __init__(self, db: Session, api_client: ApiClient, person: Person):
def __init__(self, db: Session, api_client: ApiClient, user: User):
self.db = db
self.api_client = api_client
self.person = person
self.person_id = self.person.id if self.person else None
self.user = user
self.user_id = self.user.id if self.user else None
def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.text_reply_to_post,
task_type=task.payload_type,
event_type=JournalEventType.text_reply_to_message,
payload=TextReplyEvent(role=role, length=length),
workpackage_id=work_package.id,
post_id=post_id,
task_id=task.id,
message_id=message_id,
)
def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.post_rating,
task_type=task.payload_type,
event_type=JournalEventType.message_rating,
payload=RatingEvent(rating=rating),
workpackage_id=work_package.id,
post_id=post_id,
task_id=task.id,
message_id=message_id,
)
def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.post_ranking,
task_type=task.payload_type,
event_type=JournalEventType.message_ranking,
payload=RankingEvent(ranking=ranking),
workpackage_id=work_package.id,
post_id=post_id,
task_id=task.id,
message_id=message_id,
)
def log(
@@ -87,8 +86,8 @@ class JournalWriter:
payload: JournalEvent,
task_type: str,
event_type: str = None,
workpackage_id: Optional[UUID] = None,
post_id: Optional[UUID] = None,
task_id: Optional[UUID] = None,
message_id: Optional[UUID] = None,
commit: bool = True,
) -> Journal:
if event_type is None:
@@ -97,22 +96,22 @@ class JournalWriter:
else:
event_type = type(payload).__name__
if payload.person_id is None:
payload.person_id = self.person_id
if payload.post_id is None:
payload.post_id = post_id
if payload.workpackage_id is None:
payload.workpackage_id = workpackage_id
if payload.user_id is None:
payload.user_id = self.user_id
if payload.message_id is None:
payload.message_id = message_id
if payload.task_id is None:
payload.task_id = task_id
if payload.task_type is None:
payload.task_type = task_type
entry = Journal(
person_id=self.person_id,
user_id=self.user_id,
api_client_id=self.api_client.id,
created_date=utcnow(),
event_type=event_type,
event_payload=PayloadContainer(payload=payload),
post_id=post_id,
message_id=message_id,
)
self.db.add(entry)
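`JournalWriter.log` fills in `user_id`, `message_id`, `task_id` and `task_type` on the payload only when the caller left them as `None`, so explicitly set values always win. A minimal sketch of that "fill if missing" step, using a hypothetical `EventStub` dataclass and `fill_missing` helper rather than the real pydantic models:

```python
from dataclasses import dataclass
from typing import Any, Optional
from uuid import UUID, uuid4


@dataclass
class EventStub:
    """Hypothetical stand-in for the JournalEvent payload."""

    user_id: Optional[UUID] = None
    message_id: Optional[UUID] = None
    task_id: Optional[UUID] = None
    task_type: Optional[str] = None


def fill_missing(event: EventStub, **defaults: Any) -> EventStub:
    # Only overwrite attributes that are still None on the event.
    for name, value in defaults.items():
        if getattr(event, name) is None:
            setattr(event, name, value)
    return event


preset_task = uuid4()
event = fill_missing(
    EventStub(task_id=preset_task),
    task_id=uuid4(),
    task_type="assistant_reply",
)
print(event.task_id == preset_task, event.task_type)
```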
+10 -11
View File
@@ -1,20 +1,19 @@
# -*- coding: utf-8 -*-
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .person import Person
from .person_stats import PersonStats
from .post import Post
from .post_reaction import PostReaction
from .message import Message
from .message_reaction import MessageReaction
from .task import Task
from .text_labels import TextLabels
from .work_package import WorkPackage
from .user import User
from .user_stats import UserStats
__all__ = [
"ApiClient",
"Person",
"PersonStats",
"Post",
"PostReaction",
"WorkPackage",
"User",
"UserStats",
"Message",
"MessageReaction",
"Task",
"TextLabels",
"Journal",
"JournalIntegration",
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Optional
from uuid import UUID, uuid4
+8 -9
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Literal
from oasst_backend.models.payload_column_type import payload_type
@@ -32,8 +31,8 @@ class InitialPromptPayload(TaskPayload):
@payload_type
class UserReplyPayload(TaskPayload):
type: Literal["user_reply"] = "user_reply"
class PrompterReplyPayload(TaskPayload):
type: Literal["prompter_reply"] = "prompter_reply"
conversation: protocol_schema.Conversation
hint: str | None
@@ -45,7 +44,7 @@ class AssistantReplyPayload(TaskPayload):
@payload_type
class PostPayload(BaseModel):
class MessagePayload(BaseModel):
text: str
@@ -56,13 +55,13 @@ class ReactionPayload(BaseModel):
@payload_type
class RatingReactionPayload(ReactionPayload):
type: Literal["post_rating"] = "post_rating"
type: Literal["message_rating"] = "message_rating"
rating: str
@payload_type
class RankingReactionPayload(ReactionPayload):
type: Literal["post_ranking"] = "post_ranking"
type: Literal["message_ranking"] = "message_ranking"
ranking: list[int]
@@ -81,10 +80,10 @@ class RankInitialPromptsPayload(TaskPayload):
@payload_type
class RankUserRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of user replies to a conversation."""
class RankPrompterRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of prompter replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
@payload_type
+2 -3
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid1, uuid4
@@ -33,8 +32,8 @@ class Journal(SQLModel, table=True):
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
api_client_id: UUID = Field(foreign_key="api_client.id")
event_type: str = Field(nullable=False, max_length=200)
@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlalchemy import false
from sqlmodel import Field, Index, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class Post(SQLModel, table=True):
__tablename__ = "post"
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
class Message(SQLModel, table=True):
__tablename__ = "message"
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
@@ -20,12 +20,12 @@ class Post(SQLModel, table=True):
),
)
parent_id: UUID = Field(nullable=True)
thread_id: UUID = Field(nullable=False, index=True)
workpackage_id: UUID = Field(nullable=True, index=True)
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
role: str = Field(nullable=False, max_length=128)
message_tree_id: UUID = Field(nullable=False, index=True)
task_id: UUID = Field(nullable=True, index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_post_id: str = Field(max_length=200, nullable=False)
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
@@ -34,3 +34,4 @@ class Post(SQLModel, table=True):
lang: str = Field(nullable=False, max_length=200, default="en-US")
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID
@@ -10,14 +9,14 @@ from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class PostReaction(SQLModel, table=True):
__tablename__ = "post_reaction"
class MessageReaction(SQLModel, table=True):
__tablename__ = "message_reaction"
work_package_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
task_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True)
)
person_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
user_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import json
from typing import Any, Generic, Type, TypeVar
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
@@ -11,8 +10,8 @@ from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class WorkPackage(SQLModel, table=True):
__tablename__ = "work_package"
class Task(SQLModel, table=True):
__tablename__ = "task"
id: Optional[UUID] = Field(
sa_column=sa.Column(
@@ -23,15 +22,15 @@ class WorkPackage(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
ack: Optional[bool] = None
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
frontend_ref_post_id: Optional[str] = None
thread_id: Optional[UUID] = None
parent_post_id: Optional[UUID] = None
frontend_message_id: Optional[str] = None
message_tree_id: Optional[UUID] = None
parent_message_id: Optional[UUID] = None
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
@property
+3 -2
View File
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
@@ -21,5 +20,7 @@ class TextLabels(SQLModel, table=True):
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
message_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
)
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
@@ -8,9 +7,9 @@ import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class Person(SQLModel, table=True):
__tablename__ = "person"
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
class User(SQLModel, table=True):
__tablename__ = "user"
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID
@@ -8,11 +7,11 @@ import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class PersonStats(SQLModel, table=True):
__tablename__ = "person_stats"
class UserStats(SQLModel, table=True):
__tablename__ = "user_stats"
person_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
user_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
)
leader_score: int = 0
modified_date: Optional[datetime] = Field(
@@ -20,9 +19,9 @@ class PersonStats(SQLModel, table=True):
)
reactions: int = 0 # reactions sent by user
posts: int = 0 # posts sent by user
messages: int = 0 # messages sent by user
upvotes: int = 0 # received upvotes (from other users)
downvotes: int = 0 # received downvotes (from other users)
work_reward: int = 0 # reward for workpackage completions
compare_wins: int = 0 # num times user's post won compare tasks
compare_losses: int = 0 # num times users's post lost compare tasks
task_reward: int = 0 # reward for task completions
compare_wins: int = 0 # num times user's message won compare tasks
compare_losses: int = 0 # num times user's message lost compare tasks
+491 -253
View File
@@ -1,263 +1,267 @@
# -*- coding: utf-8 -*-
import datetime
import random
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
self.db = db
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
self.journal = JournalWriter(db, api_client, self.person)
self.user = self.lookup_user(user)
self.user_id = self.user.id if self.user else None
self.journal = JournalWriter(db, api_client, self.user)
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
if not client_user:
return None
person: Person = (
self.db.query(Person)
user: User = (
self.db.query(User)
.filter(
Person.api_client_id == self.api_client.id,
Person.username == user.id,
Person.auth_method == user.auth_method,
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
)
.first()
)
if person is None:
if user is None:
# user is unknown, create new record
person = Person(
username=user.id,
display_name=user.display_name,
user = User(
username=client_user.id,
display_name=client_user.display_name,
api_client_id=self.api_client.id,
auth_method=user.auth_method,
auth_method=client_user.auth_method,
)
self.db.add(person)
self.db.add(user)
self.db.commit()
self.db.refresh(person)
elif user.display_name and user.display_name != person.display_name:
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
person.display_name = user.display_name
self.db.add(person)
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return person
return user
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
if not post_id:
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
def validate_frontend_message_id(self, message_id: str) -> None:
# TODO: Should it be replaced with fastapi/pydantic validation?
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
)
if not message_id:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
self.validate_frontend_message_id(frontend_message_id)
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
work_pack.frontend_ref_post_id = post_id
work_pack.ack = True
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(work_pack)
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
work_pack.ack = False
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(work_pack)
self.db.add(task)
self.db.commit()
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
self.validate_post_id(frontend_post_id)
post: Post = (
self.db.query(Post)
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
self.validate_frontend_message_id(frontend_message_id)
message: Message = (
self.db.query(Message)
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
.one_or_none()
)
if fail_if_missing and post is None:
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
return post
if fail_if_missing and message is None:
raise OasstError(
f"Message with frontend_message_id {frontend_message_id} not found.",
OasstErrorCode.MESSAGE_NOT_FOUND,
HTTP_404_NOT_FOUND,
)
return message
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
self.validate_post_id(post_id)
work_pack = (
self.db.query(WorkPackage)
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_post_id == post_id)
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
self.validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
)
return work_pack
return task
def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str = None) -> Post:
self.validate_post_id(post_id)
self.validate_post_id(user_post_id)
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
wp = self.fetch_workpackage_by_postid(post_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if wp is None:
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not wp.ack:
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
if wp.done:
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if not task.ack:
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
if task.done:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
# If there's no parent post assume user started new conversation
role = "user"
# If there's no parent message assume user started new conversation
role = "prompter"
depth = 0
if wp.parent_post_id:
parent_post = self.fetch_post(wp.parent_post_id)
parent_post.children_count += 1
self.db.add(parent_post)
if task.parent_message_id:
parent_message = self.fetch_message(task.parent_message_id)
parent_message.children_count += 1
self.db.add(parent_message)
depth = parent_post.depth + 1
if parent_post.role == "assistant":
role = "user"
depth = parent_message.depth + 1
if parent_message.role == "assistant":
role = "prompter"
else:
role = "assistant"
# create reply post
new_post_id = uuid4()
user_post = self.insert_post(
post_id=new_post_id,
frontend_post_id=user_post_id,
parent_id=wp.parent_post_id,
thread_id=wp.thread_id or new_post_id,
workpackage_id=wp.id,
# create reply message
new_message_id = uuid4()
user_message = self.insert_message(
message_id=new_message_id,
frontend_message_id=user_frontend_message_id,
parent_id=task.parent_message_id,
message_tree_id=task.message_tree_id or new_message_id,
task_id=task.id,
role=role,
payload=db_payload.PostPayload(text=text),
payload=db_payload.MessagePayload(text=text),
depth=depth,
)
if not wp.collective:
wp.done = True
self.db.add(wp)
if not task.collective:
task.done = True
self.db.add(task)
self.db.commit()
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
return user_post
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
return user_message
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
task = self.fetch_task_by_frontend_message_id(rating.message_id)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
f"Task payload type mismatch: {type(task_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
if rating.rating < task_payload.scale.min or rating.rating > task_payload.scale.max:
raise OasstError(
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
f"Invalid rating value: {rating.rating=} not in {task_payload.scale=}",
OasstErrorCode.RATING_OUT_OF_RANGE,
)
# store reaction to post
# store reaction to message
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
reaction = self.insert_reaction(message.id, reaction_payload)
if not task.collective:
task.done = True
self.db.add(task)
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
logger.info(f"Rating {rating.rating} stored for work_package {work_package.id}.")
self.journal.log_rating(task, message_id=message.id, rating=rating.rating)
logger.info(f"Rating {rating.rating} stored for task {task.id}.")
return reaction
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch task
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
task_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
task.payload.payload
)
match type(work_payload):
match type(task_payload):
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
# validate ranking
num_replies = len(work_payload.replies)
num_replies = len(task_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
# TODO: resolve post_id
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))):
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
# TODO: resolve post_id
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case _:
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
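The ranking check used in both `case` branches above can be read in isolation: a submitted ranking is accepted only if it is a permutation of the indices `0..n-1`. A minimal standalone sketch follows; `is_valid_ranking` is a hypothetical helper name that does not appear in the repository.

```python
def is_valid_ranking(ranking: list[int], num_items: int) -> bool:
    """Return True if `ranking` contains each index 0..num_items-1 exactly once."""
    # Sorting a permutation of 0..n-1 always yields [0, 1, ..., n-1];
    # duplicates, gaps, or a wrong length make the comparison fail.
    return sorted(ranking) == list(range(num_items))


print(is_valid_ranking([2, 0, 1], 3))  # True: a permutation of 0, 1, 2
print(is_valid_ranking([0, 0, 1], 3))  # False: index 0 repeated, 2 missing
print(is_valid_ranking([0, 1], 3))     # False: too few entries
```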
def store_task(
self,
task: protocol_schema.Task,
thread_id: UUID = None,
parent_post_id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
@@ -271,8 +275,8 @@ class PromptRepository:
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.UserReplyTask:
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
@@ -280,8 +284,8 @@ class PromptRepository:
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankUserRepliesTask:
payload = db_payload.RankUserRepliesPayload(
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
@@ -293,81 +297,85 @@ class PromptRepository:
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
wp = self.insert_work_package(
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert wp.id == task.id
return wp
assert task.id == task.id
return task
def insert_work_package(
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
thread_id: UUID = None,
parent_post_id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
) -> Task:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
task = Task(
id=id,
person_id=self.person_id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
thread_id=thread_id,
parent_post_id=parent_post_id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(wp)
self.db.add(task)
self.db.commit()
self.db.refresh(wp)
return wp
self.db.refresh(task)
return task
def insert_post(
def insert_message(
self,
*,
post_id: UUID,
frontend_post_id: str,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
thread_id: UUID,
workpackage_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.PostPayload,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Post:
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
post = Post(
id=post_id,
message = Message(
id=message_id,
parent_id=parent_id,
thread_id=thread_id,
workpackage_id=workpackage_id,
person_id=self.person_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_post_id=frontend_post_id,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
self.db.add(post)
self.db.add(message)
self.db.commit()
self.db.refresh(post)
return post
self.db.refresh(message)
return message
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
if self.person_id is None:
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
container = PayloadContainer(payload=payload)
reaction = PostReaction(
work_package_id=work_package_id,
person_id=self.person_id,
reaction = MessageReaction(
task_id=task_id,
user_id=self.user_id,
payload=container,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
@@ -383,108 +391,338 @@ class PromptRepository:
text=text_labels.text,
labels=text_labels.labels,
)
if text_labels.has_post_id:
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
model.post_id = text_labels.post_id
if text_labels.has_message_id:
self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
model.message_id = text_labels.message_id
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
def fetch_random_thread(self, require_role: str = None) -> list[Post]:
def fetch_random_message_tree(self, require_role: str = None) -> list[Message]:
"""
Loads all posts of a random thread.
Loads all messages of a random message_tree.
:param require_role: If set loads only thread which has
at least one post with given role.
:param require_role: If set loads only message_tree which has
at least one message with given role.
"""
distinct_threads = self.db.query(Post.thread_id).distinct(Post.thread_id)
distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id)
if require_role:
distinct_threads = distinct_threads.filter(Post.role == require_role)
distinct_threads = distinct_threads.subquery()
distinct_message_trees = distinct_message_trees.filter(Message.role == require_role)
distinct_message_trees = distinct_message_trees.subquery()
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
return thread_posts
random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1)
message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all()
return message_tree_messages
def fetch_random_conversation(self, last_post_role: str = None) -> list[Post]:
def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]:
"""
Picks a random linear conversation starting from any root post
and ending somewhere in the thread, possibly at the root itself.
Picks a random linear conversation starting from any root message
and ending somewhere in the message_tree, possibly at the root itself.
:param last_post_role: If set will form a conversation ending with a post
:param last_message_role: If set will form a conversation ending with a message
created by this role. Necessary for the tasks like "user_reply" where
the user should reply as a human and hence the last message of the conversation
needs to have "assistant" role.
"""
thread_posts = self.fetch_random_thread(last_post_role)
if not thread_posts:
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
if last_post_role:
conv_posts = [p for p in thread_posts if p.role == last_post_role]
conv_posts = [random.choice(conv_posts)]
messages_tree = self.fetch_random_message_tree(last_message_role)
if not messages_tree:
raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND)
if last_message_role:
conv_messages = [m for m in messages_tree if m.role == last_message_role]
conv_messages = [random.choice(conv_messages)]
else:
conv_posts = [random.choice(thread_posts)]
thread_posts = {p.id: p for p in thread_posts}
conv_messages = [random.choice(messages_tree)]
messages_tree = {m.id: m for m in messages_tree}
while True:
if not conv_posts[-1].parent_id:
if not conv_messages[-1].parent_id:
# reached the start of the conversation
break
parent_post = thread_posts[conv_posts[-1].parent_id]
conv_posts.append(parent_post)
parent_message = messages_tree[conv_messages[-1].parent_id]
conv_messages.append(parent_message)
return list(reversed(conv_posts))
return list(reversed(conv_messages))
def fetch_random_initial_prompts(self, size: int = 5):
posts = self.db.query(Post).filter(Post.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return posts
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return messages
def fetch_thread(self, thread_id: UUID):
return self.db.query(Post).filter(Post.thread_id == thread_id).all()
def fetch_message_tree(self, message_tree_id: UUID):
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None):
parent = self.db.query(Post.id).filter(Post.children_count > 1)
if post_role:
parent = parent.filter(Post.role == post_role)
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
"""
Fetch a conversation with multiple possible replies to it.
This function finds a random message with >1 replies,
forms a conversation from the corresponding message tree root up to this message
and fetches up to max_size possible replies in continuation to this conversation.
"""
parent = self.db.query(Message.id).filter(Message.children_count > 1)
if message_role:
parent = parent.filter(Message.role == message_role)
parent = parent.order_by(func.random()).limit(1)
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
replies = (
self.db.query(Message).filter(Message.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
)
if not replies:
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
thread = self.fetch_thread(replies[0].thread_id)
thread = {p.id: p for p in thread}
thread_posts = [thread[replies[0].parent_id]]
message_tree = self.fetch_message_tree(replies[0].message_tree_id)
message_tree = {p.id: p for p in message_tree}
conversation = [message_tree[replies[0].parent_id]]
while True:
if not thread_posts[-1].parent_id:
if not conversation[-1].parent_id:
# reached start of the conversation
break
parent_post = thread[thread_posts[-1].parent_id]
thread_posts.append(parent_post)
parent_message = message_tree[conversation[-1].parent_id]
conversation.append(parent_message)
thread_posts = reversed(thread_posts)
conversation = reversed(conversation)
return thread_posts, replies
return conversation, replies
def fetch_post(self, post_id: UUID) -> Optional[Post]:
return self.db.query(Post).filter(Post.id == post_id).one()
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
return message
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
self.validate_post_id(post_id)
wp = self.fetch_workpackage_by_postid(post_id)
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
self.validate_frontend_message_id(frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if not wp:
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not allow_personal_tasks and not wp.collective:
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
if wp.done:
raise OasstError("Already closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
if not task:
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Already closed", OasstErrorCode.TASK_ALREADY_DONE)
wp.done = True
self.db.add(wp)
task.done = True
self.db.add(task)
self.db.commit()
@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
Pick messages from a collection so that the result makes a linear conversation
starting from a message tree root and up to the given message.
Returns an ordered list of messages starting from the message tree root.
"""
if isinstance(messages, list):
messages = {m.id: m for m in messages}
if not isinstance(messages, dict):
# This should not normally happen
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
conv = [last_message]
while conv[-1].parent_id:
if conv[-1].parent_id not in messages:
# Can't form a continuous conversation
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
parent_message = messages[conv[-1].parent_id]
conv.append(parent_message)
return list(reversed(conv))
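The parent-pointer walk in `trace_conversation` can be tried without a database. The sketch below is a simplified illustration, not the repository code: `Msg` is a hypothetical stand-in for the `Message` model with integer ids, and a plain `KeyError` replaces `OasstError`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Msg:
    # Hypothetical stand-in for the Message model: only the fields the walk needs.
    id: int
    parent_id: Optional[int]


def trace_conversation(messages: dict[int, Msg], last: Msg) -> list[Msg]:
    """Follow parent_id links from `last` up to the root, then return root-first order."""
    conv = [last]
    while conv[-1].parent_id is not None:
        parent_id = conv[-1].parent_id
        if parent_id not in messages:
            # The collection does not contain a continuous path to the root.
            raise KeyError(f"parent {parent_id} missing")
        conv.append(messages[parent_id])
    return list(reversed(conv))


tree = {m.id: m for m in [Msg(1, None), Msg(2, 1), Msg(3, 1), Msg(4, 2)]}
print([m.id for m in trace_conversation(tree, tree[4])])  # [1, 2, 4]
```

Building the list leaf-first and reversing once at the end keeps each step a cheap `append`.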
def fetch_message_conversation(self, message: Message | UUID) -> list[Message]:
"""
Fetch a conversation from the tree root and up to this message.
"""
if isinstance(message, UUID):
message = self.fetch_message(message)
tree_messages = self.fetch_message_tree(message.message_tree_id)
return self.trace_conversation(tree_messages, message)
def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]:
"""
Fetch message tree this message belongs to.
"""
if isinstance(message, UUID):
message = self.fetch_message(message)
return self.fetch_message_tree(message.message_tree_id)
def fetch_message_children(self, message: Message | UUID) -> list[Message]:
"""
Get all direct children of this message
"""
if isinstance(message, Message):
message = message.id
children = self.db.query(Message).filter(Message.parent_id == message).all()
return children
@staticmethod
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
children = defaultdict(list)
for msg in messages:
children[msg.parent_id].append(msg)
def _traverse_subtree(m: Message):
for child in children[m.id]:
yield child
yield from _traverse_subtree(child)
return list(_traverse_subtree(root))
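The subtree traversal in `trace_descendants` groups messages by `parent_id` and then walks depth-first from the root. A self-contained sketch of the same idea, again using a hypothetical `Msg` dataclass in place of the `Message` model:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class Msg:
    # Hypothetical stand-in for the Message model.
    id: int
    parent_id: Optional[int]


def trace_descendants(root: Msg, messages: list[Msg]) -> list[Msg]:
    """Return all descendants of `root` in depth-first, pre-order sequence."""
    children = defaultdict(list)
    for msg in messages:
        children[msg.parent_id].append(msg)  # index messages by their parent

    def walk(m: Msg) -> Iterator[Msg]:
        for child in children[m.id]:
            yield child
            yield from walk(child)

    return list(walk(root))


root = Msg(1, None)
others = [Msg(2, 1), Msg(3, 1), Msg(4, 2), Msg(5, 99)]
print([m.id for m in trace_descendants(root, others)])  # [2, 4, 3]
```

Message 5 is dropped because its parent (99) is not reachable from the root, which is how the candidate set fetched by depth gets narrowed to the actual subtree.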
def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]:
"""
Find all descendant messages to this message.
This function creates a subtree of messages starting from given root message.
"""
if isinstance(message, UUID):
message = self.fetch_message(message)
desc = self.db.query(Message).filter(
Message.message_tree_id == message.message_tree_id, Message.depth > message.depth
)
if max_depth is not None:
desc = desc.filter(Message.depth <= max_depth)
desc = desc.all()
return self.trace_descendants(message, desc)
def fetch_longest_conversation(self, message: Message | UUID) -> list[Message]:
tree = self.fetch_tree_from_message(message)
max_message = max(tree, key=lambda m: m.depth)
return self.trace_conversation(tree, max_message)
def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Message, list[Message]]:
tree = self.fetch_tree_from_message(message)
max_message = max(tree, key=lambda m: m.children_count)
return max_message, [m for m in tree if m.parent_id == max_message.id]
def query_messages(
self,
user_id: Optional[UUID] = None,
username: Optional[str] = None,
api_client_id: Optional[UUID] = None,
desc: bool = True,
limit: Optional[int] = 10,
start_date: Optional[datetime.datetime] = None,
end_date: Optional[datetime.datetime] = None,
only_roots: bool = False,
deleted: Optional[bool] = None,
) -> list[Message]:
if not self.api_client.trusted and not api_client_id:
# Let unprivileged api clients query their own messages without api_client_id being set
api_client_id = self.api_client.id
if not self.api_client.trusted and api_client_id != self.api_client.id:
# Unprivileged api client asks for foreign messages
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
messages = self.db.query(Message)
if user_id:
messages = messages.filter(Message.user_id == user_id)
if username:
messages = messages.join(User)
messages = messages.filter(User.username == username)
if api_client_id:
messages = messages.filter(Message.api_client_id == api_client_id)
if start_date:
messages = messages.filter(Message.created_date >= start_date)
if end_date:
messages = messages.filter(Message.created_date < end_date)
if only_roots:
messages = messages.filter(Message.parent_id.is_(None))
if deleted is not None:
messages = messages.filter(Message.deleted == deleted)
if desc:
messages = messages.order_by(Message.created_date.desc())
else:
messages = messages.order_by(Message.created_date.asc())
if limit is not None:
messages = messages.limit(limit)
# TODO: Pagination could be great at some point
return messages.all()
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
"""
Marks messages and, if `recursive` is set, all their descendants as deleted.
"""
if isinstance(messages, (Message, UUID)):
messages = [messages]
ids = []
for message in messages:
if isinstance(message, UUID):
ids.append(message)
elif isinstance(message, Message):
ids.append(message.id)
else:
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
query = update(Message).where(Message.id.in_(ids)).values(deleted=True)
self.db.execute(query)
parent_ids = ids
if recursive:
while parent_ids:
query = (
update(Message).filter(Message.parent_id.in_(parent_ids)).values(deleted=True).returning(Message.id)
)
parent_ids = self.db.execute(query).scalars().all()
self.db.commit()
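The recursive branch above works level by level: each `UPDATE ... RETURNING` marks the children of the previous level and returns their ids as the next set of parents. The in-memory sketch below mimics that loop; `mark_deleted` and the `parent_of` mapping are hypothetical names used only for illustration.

```python
from typing import Optional


def mark_deleted(parent_of: dict[int, Optional[int]], ids: list[int], recursive: bool = True) -> set[int]:
    """Return the ids that end up marked: `ids` plus, if recursive, all their descendants."""
    deleted = set(ids)
    frontier = list(ids)
    while recursive and frontier:
        # One "round trip": find the direct children of the current level.
        frontier = [m for m, p in parent_of.items() if p in frontier and m not in deleted]
        deleted.update(frontier)
    return deleted


# id -> parent id; 1 and 6 are roots
parent_of = {1: None, 2: 1, 3: 1, 4: 2, 5: 4, 6: None}
print(sorted(mark_deleted(parent_of, [2])))                   # [2, 4, 5]
print(sorted(mark_deleted(parent_of, [2], recursive=False)))  # [2]
```

The number of iterations equals the depth of the deepest affected subtree, not the number of messages.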
def get_stats(self) -> SystemStats:
"""
Get data stats such as number of all messages in the system,
number of deleted and active messages and number of message trees.
"""
deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted)
nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None))
query = deleted.union_all(nthreads)
result = {k: v for k, v in query.all()}
return SystemStats(
all=result.get(True, 0) + result.get(False, 0),
active=result.get(False, 0),
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
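`get_stats` relies on the combined query returning `(key, count)` rows, where the key is the `deleted` flag for the grouped counts and `None` for the root-message count. A small sketch of the aggregation step alone, using made-up row values and a hypothetical `summarize` helper:

```python
from typing import Optional


def summarize(rows: list[tuple[Optional[bool], int]]) -> dict[str, int]:
    """Turn (key, count) rows into the four counters reported by the stats call."""
    result = dict(rows)
    return {
        "all": result.get(True, 0) + result.get(False, 0),
        "active": result.get(False, 0),
        "deleted": result.get(True, 0),
        "message_trees": result.get(None, 0),
    }


# Illustrative values only: 40 active, 2 deleted, 7 root messages.
rows = [(False, 40), (True, 2), (None, 7)]
print(summarize(rows))
```

Using `.get(..., 0)` keeps the result well defined when a group is absent, for example when no message has been deleted yet.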
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted.is_not(True), Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)
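
A note on filter expressions such as the one in `get_user_leaderboard`: Python's `is` / `is not` operators cannot be overloaded, so a query builder never sees them, whereas `==` and explicit methods such as `.is_()` can produce SQL. The sketch below illustrates this with a hypothetical `FakeColumn` class; it is a stand-in written for this example and not SQLAlchemy's real column type.

```python
class FakeColumn:
    """Hypothetical stand-in for an ORM column (not SQLAlchemy's class)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> str:
        # `==` is overloadable, so it can build a SQL fragment.
        return f"{self.name} = {other!r}"

    def is_(self, other: object) -> str:
        # An explicit method call can build a SQL fragment too.
        return f"{self.name} IS {str(other).upper()}"


deleted = FakeColumn("deleted")

# Identity comparison happens in Python: the column object is never
# the singleton True, so this is just the plain bool True.
identity_check = deleted is not True

# The method call yields something a query builder could render as SQL.
method_check = deleted.is_(False)

print(identity_check)  # a Python bool, no SQL involved
print(method_check)    # a SQL-like fragment
```

Passing a plain `True` into a filter typically means the condition does not restrict the rows at all, which is easy to miss because the query still runs.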
+1
View File
@@ -1,5 +1,6 @@
alembic==1.8.1
fastapi==0.88.0
fastapi-limiter==0.1.5
loguru==0.6.0
numpy==1.22.4
psycopg2-binary==2.9.5
+6 -6
View File
@@ -16,8 +16,8 @@ Setup requires a few steps:
copilot app init --domain your_domain.com
```
This will initialize and register a variety of URLs with your
`your_domain.com`. Replace with a proper domain to set up SSL certificates.
This will initialize and register a variety of URLs with your `your_domain.com`.
Replace with a proper domain to set up SSL certificates.
```sh
copilot env deploy
@@ -29,10 +29,10 @@ This will create a variety of aws roles and services needed for deployment.
copilot deploy
```
This will deploy the services but it won't be 100% ready for usage. Before
being ready, we have to inspect the AWS Secrets manager and extract out the
database credentials. Read those credentials then put them, and a few other
secrets, in a `secrets.yml` file like the following:
This will deploy the services but it won't be 100% ready for usage. Before being
ready, we have to inspect the AWS Secrets manager and extract out the database
credentials. Read those credentials then put them, and a few other secrets, in a
`secrets.yml` file like the following:
```yaml
DATABASE_URL:
+38
View File
@@ -0,0 +1,38 @@
# The manifest for the "api" service.
# Read the full specification for the "Load Balanced Web Service" type at:
# https://aws.github.io/copilot-cli/docs/manifest/lb-web-service/
name: api
type: Load Balanced Web Service
http:
path: "/"
healthcheck:
path: "/docs"
image:
build:
dockerfile: docker/Dockerfile.backend
context: ./
port: 8080
cpu: 256
memory: 512
platform: linux/x86_64
count: 1
exec: true
network:
connect: true
environments:
staging:
variables:
# Note: this has to be a valid JSON list for Pydantic to parse it.
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
DEBUG_ALLOW_ANY_API_KEY: True
DEBUG_SKIP_API_KEY_CHECK: True
MAX_WORKERS: 1
secrets:
# Note: URI, not URL.
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
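
The manifest comment about `BACKEND_CORS_ORIGINS` needing to be a valid JSON list can be illustrated with a small sketch. It uses only the standard `json` module; `parse_origins` is a hypothetical helper for this example, and how Pydantic parses the value internally is not shown in this change.

```python
import json


def parse_origins(raw: str) -> list[str]:
    """Parse an environment value into a list of origins.

    Raises ValueError when the value is not a JSON list
    (json.JSONDecodeError is a subclass of ValueError).
    """
    value = json.loads(raw)
    if not isinstance(value, list):
        raise ValueError(f"expected a JSON list, got {type(value).__name__}")
    return [str(item) for item in value]


# Quoted as in the manifest: a JSON list inside a YAML string.
good = '["https://web.staging.open-assistant.surfacedata.org"]'
# A bare URL is not valid JSON and would be rejected.
bad = "https://web.staging.open-assistant.surfacedata.org"

print(parse_origins(good))
try:
    parse_origins(bad)
except ValueError as err:
    print(f"rejected: {err}")
```

This is why the YAML value is wrapped in single quotes with the double-quoted URL inside square brackets.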
-144
View File
@@ -1,144 +0,0 @@
Parameters:
App:
Type: String
Description: Your application's name.
Env:
Type: String
Description: The environment name your service, job, or workflow is being deployed to.
Name:
Type: String
Description: The name of the service, job, or workflow being deployed.
# Customize your Aurora Serverless cluster by setting the default value of the following parameters.
webclusterDBName:
Type: String
Description: The name of the initial database to be created in the Aurora Serverless v2 cluster.
Default: oassist_web
# Cannot have special characters
# Naming constraints: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_Limits.html#RDS_Limits.Constraints
Mappings:
webclusterEnvScalingConfigurationMap:
staging:
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
All:
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
Resources:
webclusterDBSubnetGroup:
Type: "AWS::RDS::DBSubnetGroup"
Properties:
DBSubnetGroupDescription: Group of Copilot private subnets for Aurora Serverless v2 cluster.
SubnetIds:
!Split [",", { "Fn::ImportValue": !Sub "${App}-${Env}-PrivateSubnets" }]
webclusterSecurityGroup:
Metadata:
"aws:copilot:description": "A security group for your workload to access the Aurora Serverless v2 cluster webcluster"
Type: "AWS::EC2::SecurityGroup"
Properties:
GroupDescription: !Sub "The Security Group for ${Name} to access Aurora Serverless v2 cluster webcluster."
VpcId:
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
Tags:
- Key: Name
Value: !Sub "copilot-${App}-${Env}-${Name}-Aurora"
webclusterDBClusterSecurityGroup:
Metadata:
"aws:copilot:description": "A security group for your Aurora Serverless v2 cluster webcluster"
Type: AWS::EC2::SecurityGroup
Properties:
GroupDescription: The Security Group for the Aurora Serverless v2 cluster.
SecurityGroupIngress:
- ToPort: 5432
FromPort: 5432
IpProtocol: tcp
Description: !Sub "From the Aurora Security Group of the workload ${Name}."
SourceSecurityGroupId: !Ref webclusterSecurityGroup
VpcId:
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
webclusterAuroraSecret:
Metadata:
"aws:copilot:description": "A Secrets Manager secret to store your DB credentials"
Type: AWS::SecretsManager::Secret
Properties:
Description: !Sub Aurora main user secret for ${AWS::StackName}
GenerateSecretString:
SecretStringTemplate: '{"username": "postgres"}'
GenerateStringKey: "password"
ExcludePunctuation: true
IncludeSpace: false
PasswordLength: 16
webclusterDBClusterParameterGroup:
Metadata:
"aws:copilot:description": "A DB parameter group for engine configuration values"
Type: "AWS::RDS::DBClusterParameterGroup"
Properties:
Description: !Ref "AWS::StackName"
Family: "aurora-postgresql14"
Parameters:
client_encoding: "UTF8"
webclusterDBCluster:
Metadata:
"aws:copilot:description": "The webcluster Aurora Serverless v2 database cluster"
Type: "AWS::RDS::DBCluster"
Properties:
MasterUsername:
!Join [
"",
[
"{{resolve:secretsmanager:",
!Ref webclusterAuroraSecret,
":SecretString:username}}",
],
]
MasterUserPassword:
!Join [
"",
[
"{{resolve:secretsmanager:",
!Ref webclusterAuroraSecret,
":SecretString:password}}",
],
]
DatabaseName: !Ref webclusterDBName
Engine: "aurora-postgresql"
EngineVersion: "14.4"
DBClusterParameterGroupName: !Ref webclusterDBClusterParameterGroup
DBSubnetGroupName: !Ref webclusterDBSubnetGroup
Port: 5432
VpcSecurityGroupIds:
- !Ref webclusterDBClusterSecurityGroup
ServerlessV2ScalingConfiguration:
# Replace "All" below with "!Ref Env" to set different autoscaling limits per environment.
MinCapacity:
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMinCapacity]
MaxCapacity:
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMaxCapacity]
webclusterDBWriterInstance:
Metadata:
"aws:copilot:description": "The webcluster Aurora Serverless v2 writer instance"
Type: "AWS::RDS::DBInstance"
Properties:
DBClusterIdentifier: !Ref webclusterDBCluster
DBInstanceClass: db.serverless
Engine: "aurora-postgresql"
PromotionTier: 1
AvailabilityZone: !Select
- 0
- !GetAZs
Ref: AWS::Region
webclusterSecretAuroraClusterAttachment:
Type: AWS::SecretsManager::SecretTargetAttachment
Properties:
SecretId: !Ref webclusterAuroraSecret
TargetId: !Ref webclusterDBCluster
TargetType: AWS::RDS::DBCluster
Outputs:
webclusterSecret: # injected as WEBCLUSTER_SECRET environment variable by Copilot.
Description: "The JSON secret that holds the database username and password. Fields are 'host', 'port', 'dbname', 'username', 'password', 'dbClusterIdentifier' and 'engine'"
Value: !Ref webclusterAuroraSecret
webclusterSecurityGroup:
Description: "The security group to attach to the workload."
Value: !Ref webclusterSecurityGroup
+1 -1
View File
@@ -26,6 +26,7 @@ environments:
staging:
variables:
NEXTAUTH_URL: https://web.staging.open-assistant.surfacedata.org
FASTAPI_URL: https://api.staging.open-assistant.surfacedata.org
secrets:
DATABASE_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/DATABASE_URL
@@ -37,5 +38,4 @@ secrets:
EMAIL_SERVER_USER: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_SERVER_USER
EMAIL_FROM: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_FROM
FASTAPI_KEY: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_KEY
FASTAPI_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_URL
NEXTAUTH_SECRET: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/NEXTAUTH_SECRET
+7
View File
@@ -0,0 +1,7 @@
BOT_TOKEN=<discord bot token>
DECLARE_GLOBAL_COMMANDS=<testing guild id>
OWNER_IDS=[<your user id>, <other user ids>]
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
OASST_API_URL="http://localhost:8080" # No trailing '/'
OASST_API_KEY=""
+7
View File
@@ -1,3 +1,10 @@
.env
*.egg-info/
__pycache__/
.venv
.nox
.env
# Database files
*.db
+159 -8
View File
@@ -1,20 +1,171 @@
# Open-Assistant Data Collection Discord Bot
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large language model. You and other people can teach the bot how to respond to user requests by demonstration and by grading and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
This bot collects human feedback to create a dataset for RLHF-alignment of an
assistant chat bot based on a large language model. You and other people can
teach the bot how to respond to user requests by demonstration and by ranking
the bot's outputs. If you want to learn more about RLHF please refer
[to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
## Invite official bot
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
To add the official Open-Assistant data collection bot to your discord server
[click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot%20applications.commands).
The bot needs access to read the contents of user text messages.
## Bot token for development
## Contributing
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get a bot token.
If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the
[large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7).
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
### Bot Setup
The simplest way to configure the token is via an `.env` file:
1. Create a new discord application at the
[Discord Developer Portal](https://discord.com/developers/applications)
1. Go to the "Bot" tab and create a new bot
1. Scroll down to "Privileged Gateway Intents" and enable the following options:
- Server Members Intent
- Presence Intent
- Message Content Intent
This page also contains the bot token, which you will need to add to the `.env`
file later.
2. Go to the "OAuth2" tab and scroll to "Default Authorization Link"
3. Set "AUTHORIZATION METHOD" to "In-app Authorization"
4. Select the "bot" and "applications.commands" scopes
5. For testing and local development, it's easiest to set "BOT PERMISSIONS" to
"Administrator"
Remember to save your changes.
6. Copy the "CLIENT ID" from the top of the page and replace it in the link
below to invite your bot.
```
BOT_TOKEN=XYZABC123...
https://discord.com/oauth2/authorize?client_id=YOUR_CLIENT_ID_HERE&permissions=8&scope=bot%20applications.commands
```
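
If you prefer to generate the invite link rather than edit it by hand, a minimal sketch could look like the following. The helper name `build_invite_url` is hypothetical, and the default `permissions=8` simply mirrors the link above.

```python
from urllib.parse import quote, urlencode


def build_invite_url(client_id: str, permissions: int = 8) -> str:
    """Build an invite link with the "bot" and "applications.commands" scopes."""
    query = urlencode(
        {
            "client_id": client_id,
            "permissions": permissions,
            "scope": "bot applications.commands",
        },
        quote_via=quote,  # encode the space as %20 instead of "+"
    )
    return f"https://discord.com/oauth2/authorize?{query}"


print(build_invite_url("YOUR_CLIENT_ID_HERE"))
```

Using `quote_via=quote` keeps the output in the same form as the link shown above.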
### Environment Setup
To run the bot:
Install dependency module `oasst-shared`
```bash
cd oasst-shared
pip install -e .
```
```bash
cp .env.example .env
# edit .env and add your bot token and other values
# BOT_TOKEN is given by the discord developer portal when you create a bot
# DECLARE_GLOBAL_COMMANDS is the id of the server where you added the bot (right click on the server icon and copy id)
# OWNER_IDS can be left as an empty list
python -V # 3.10
pip install -r requirements.txt
# in the discord-bot folder
python -m bot
```
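
To sanity-check your `.env` values before starting the bot, something like the sketch below may help. It is only an illustration using the standard library: `parse_env` is a hypothetical helper, the sample values are made up, and the bot's real settings loader in `bot/settings.py` may behave differently.

```python
import json


def parse_env(text: str) -> dict[str, str]:
    """Parse KEY=VALUE lines, skipping blanks and full-line comments."""
    values: dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop a trailing inline comment such as `"/" # some note`.
        value = value.split(" #", 1)[0].strip()
        values[key.strip()] = value.strip('"')
    return values


# Made-up sample in the style of `.env.example`.
sample = """\
BOT_TOKEN=abc123
OWNER_IDS=[1, 2]
PREFIX="/" # slash command prefix in DMs
OASST_API_URL="http://localhost:8080"
"""

env = parse_env(sample)
owner_ids = json.loads(env["OWNER_IDS"])  # list-valued settings are JSON here

assert env["PREFIX"], "PREFIX must not be empty"
assert not env["OASST_API_URL"].endswith("/"), "no trailing '/'"
print(env, owner_ids)
```

The two assertions mirror the notes in `.env.example`: the prefix must not be empty and the API URL must not end with a slash.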
Before you push, make sure the `pre-commit` hooks are installed and run
successfully.
```bash
pip install pre-commit
pre-commit install
...
git add .
git commit -m "<good commit message>"
# if the pre-commit hooks fail or modify files, stage the changes and commit again
git add .
git commit -m "<good commit message>"
```
### Resources
#### Structure
Important files
```graphql
.env # Environment variables
.env.example # Example environment variables
CONTRIBUTING.md # This file
README.md # Project readme
EXAMPLES.md # Examples for commands and listeners
requirements.txt # Requirements
bot/
__main__.py # Entrypoint
api_client.py # API Client for interacting with the backend
bot.py # Main bot class
settings.py # Settings and secrets
utils.py # Utility Functions
db/ # Database related code
database.db # SQLite database
schema.sql # SQL schema
schemas.py # Python table schemas
extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
work.py # Task handling logic <-- most important file
guild_settings.py # Server specific settings
hot_reload.py # Utility for hot-reloading extensions during development
```
#### Adding a new command/listener
1. Create a new file in the `extensions` folder
2. Copy the template below
```py
# -*- coding: utf-8 -*-
"""My plugin."""
import lightbulb
plugin = lightbulb.Plugin("MyPlugin")
# Add your commands here
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin from the bot."""
bot.remove_plugin(plugin)
```
#### Docs
Discord
- [Discord API Reference](https://discord.com/developers/docs/intro)
`hikari` (main framework)
- [Hikari Repo](https://github.com/hikari-py/hikari)
- [Hikari Docs](https://docs.hikari-py.dev/en/latest/)
`lightbulb` (command handler)
- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb)
- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/)
`miru` (component handler: buttons, modals, etc... )
- [Miru Repo](https://github.com/HyperGH/hikari-miru)
-17
View File
@@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
from bot import OpenAssistantBot
from bot_settings import settings
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
if __name__ == "__main__":
bot = OpenAssistantBot(
settings.BOT_TOKEN,
bot_channel_name=settings.BOT_CHANNEL_NAME,
backend_url=settings.BACKEND_URL,
api_key=settings.API_KEY,
owner_id=settings.OWNER_ID,
template_dir=settings.TEMPLATE_DIR,
debug=settings.DEBUG,
)
bot.run()
-79
View File
@@ -1,79 +0,0 @@
# -*- coding: utf-8 -*-
import enum
from typing import Optional, Type
import requests
from oasst_shared.schemas import protocol as protocol_schema
class TaskType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
class ApiClient:
def __init__(self, backend_url: str, api_key: str):
self.backend_url = backend_url
self.api_key = api_key
task_models_map: dict[str, Type[protocol_schema.Task]] = {
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.user_reply: protocol_schema.UserReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
self.task_models_map = task_models_map
def post(self, path: str, json: dict) -> dict:
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
response.raise_for_status()
return response.json()
def _parse_task(self, data: dict) -> protocol_schema.Task:
if not isinstance(data, dict):
raise ValueError("dict expected")
task_type = data.get("type")
if task_type not in self.task_models_map:
raise RuntimeError(f"Unsupported task type: {task_type}")
return self.task_models_map[task_type].parse_obj(data)
def fetch_task(
self,
task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None,
collective: bool = False,
) -> protocol_schema.Task:
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
data = self.post("/api/v1/tasks/", req.dict())
return self._parse_task(data)
def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False
) -> protocol_schema.Task:
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
def ack_task(self, task_id: str, post_id: str) -> None:
req = protocol_schema.TaskAck(post_id=post_id)
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
def nack_task(self, task_id: str, reason: str) -> None:
req = protocol_schema.TaskNAck(reason=reason)
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
data = self.post("/api/v1/tasks/interaction", interaction.dict())
return self._parse_task(data)
-283
View File
@@ -1,283 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
bot_channel_name: str,
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
if isinstance(owner_id, str):
owner_id = int(owner_id)
self.owner_id = owner_id
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
if message.author != client.user:
await self.handle_message(message)
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old threads.")
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
handler = task_handlers.UserReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"starting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "failed")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
logger.info(f"sync tree command received: {command}")
if command == "sync.copy_global":
await self.tree.copy_global_to(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.clear_guild":
self.tree.clear_commands(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.guild":
synced = await self.tree.sync(guild=message.guild)
else:
synced = await self.tree.sync()
logger.info(f"Synced {len(synced)} commands")
await message.reply(f"Synced {len(synced)} commands")
async def handle_command(self, message: discord.Message, is_owner: bool):
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
if channel.type == discord.ChannelType.text and channel.name == channel_name:
return channel
def run(self):
"""Run bot loop blocking."""
self.client.run(self.bot_token)
+1
View File
@@ -0,0 +1 @@
"""The official Open-Assistant Discord Bot."""
+16
View File
@@ -0,0 +1,16 @@
"""Entry point for the bot."""
import logging
import os
from bot.bot import bot
logger = logging.getLogger(__name__)
if __name__ == "__main__":
if os.name != "nt":
import uvloop
uvloop.install()
logger.info("Starting bot")
bot.run()
+126
View File
@@ -0,0 +1,126 @@
"""Bot logic."""
from datetime import datetime
import aiosqlite
import hikari
import lightbulb
import miru
from bot.settings import Settings
from bot.utils import mention
from oasst_shared.api_client import OasstApiClient
settings = Settings()
# TODO: Revisit cache settings
bot = lightbulb.BotApp(
token=settings.bot_token,
logs="DEBUG",
prefix=settings.prefix,
default_enabled_guilds=settings.declare_global_commands,
owner_ids=settings.owner_ids,
intents=hikari.Intents.ALL,
)
@bot.listen()
async def on_starting(event: hikari.StartingEvent):
"""Setup."""
miru.install(bot) # component handler
bot.load_extensions_from("./bot/extensions") # load extensions
bot.d.db = await aiosqlite.connect("./bot/db/database.db")
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
await bot.d.db.commit()
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
# A `dict` whose values are `tuple[hikari.Message | None, UUID | None]`; it maps user IDs to (task msg, task UUID).
# Either both are `None` or both are not `None`.
# If both are `None`, the user is not currently selecting a task.
# TODO: Grow this on startup so we don't have to re-allocate memory every time it needs to grow
bot.d.currently_working = {}
@bot.listen()
async def on_stopping(event: hikari.StoppingEvent):
"""Cleanup."""
await bot.d.db.close()
await bot.d.oasst_api.close()
async def _send_error_embed(
content: str, exception: lightbulb.errors.LightbulbError | BaseException, ctx: lightbulb.Context
) -> None:
ctx.command
embed = hikari.Embed(
title=f"`{exception.__class__.__name__}` Error{f' in `/{ctx.command.name}`' if ctx.command else '' }",
description=content,
color=0xFF0000,
timestamp=datetime.now().astimezone(),
).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url))
await ctx.respond(embed=embed)
@bot.listen(lightbulb.CommandErrorEvent)
async def on_error(event: lightbulb.CommandErrorEvent) -> None:
"""Error handler for the bot."""
# Unwrap the exception to get the original cause
exc = event.exception.__cause__ or event.exception
ctx = event.context
if not ctx.bot.rest.is_alive:
return
if isinstance(event.exception, lightbulb.CommandInvocationError):
if not event.context.command:
await _send_error_embed("Something went wrong", exc, ctx)
else:
await _send_error_embed(
f"Something went wrong during invocation of command `{event.context.command.name}`.", exc, ctx
)
raise event.exception
# Not an owner
if isinstance(exc, lightbulb.NotOwner):
await _send_error_embed("You are not the owner of this bot.", exc, ctx)
# Command is on cooldown
elif isinstance(exc, lightbulb.CommandIsOnCooldown):
await _send_error_embed(f"This command is on cooldown. Retry in `{exc.retry_after:.2f}` seconds.", exc, ctx)
# Missing permissions
elif isinstance(exc, lightbulb.errors.MissingRequiredPermission):
await _send_error_embed(
f"You do not have permission to use this command. Missing permissions: {exc.missing_perms}", exc, ctx
)
# Missing roles
elif isinstance(exc, lightbulb.errors.MissingRequiredRole):
assert event.context.guild_id is not None # Roles only exist in guilds
await _send_error_embed(
f"You do not have the correct role to use this command. Missing role(s): {[mention(r, 'role') for r in exc.missing_roles]}",
exc,
ctx,
)
# Only a guild command
elif isinstance(exc, lightbulb.errors.OnlyInGuild):
await _send_error_embed("This command can only be run in servers.", exc, ctx)
# Only a DM command
elif isinstance(exc, lightbulb.errors.OnlyInDM):
await _send_error_embed("This command can only be run in DMs.", exc, ctx)
# Not enough arguments
elif isinstance(exc, lightbulb.errors.NotEnoughArguments):
await _send_error_embed(
f"Not enough arguments were supplied to the command. {[opt.name for opt in exc.missing_options]}", exc, ctx
)
# Bot missing permission
elif isinstance(exc, lightbulb.errors.BotMissingRequiredPermission):
await _send_error_embed(
f"The bot does not have the correct permission(s) to execute this command. Missing permissions: {exc.missing_perms}",
exc,
ctx,
)
elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment):
await _send_error_embed("Not enough attachments were supplied to this command.", exc, ctx)
elif isinstance(exc, lightbulb.errors.CommandNotFound):
await ctx.respond(f"`/{exc.invoked_with}` is not a valid command. Use `/help` to see a list of commands.")
else:
raise exc
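
The error handler above first unwraps the exception with `event.exception.__cause__ or event.exception`. The sketch below shows that pattern in isolation using only built-in exception chaining; `CommandInvocationError` here is a hypothetical stand-in class defined for the example, not the one from `lightbulb`.

```python
class CommandInvocationError(Exception):
    """Hypothetical wrapper error standing in for a framework's invocation error."""


def run_command() -> None:
    """Simulate a command whose failure gets wrapped by the framework."""
    try:
        raise PermissionError("not the owner")
    except PermissionError as original:
        # `raise ... from ...` stores the original error in `__cause__`.
        raise CommandInvocationError("command failed") from original


def unwrap(exception: BaseException) -> BaseException:
    """Return the original cause if there is one, else the exception itself."""
    return exception.__cause__ or exception


try:
    run_command()
except CommandInvocationError as wrapped:
    root = unwrap(wrapped)

print(type(root).__name__, root)
```

Exceptions raised without `from` have `__cause__` set to `None`, so `unwrap` returns them unchanged and the `isinstance` checks still see the right type.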
+5
View File
@@ -0,0 +1,5 @@
-- Sqlite3 schema for the bot
CREATE TABLE IF NOT EXISTS guild_settings (
guild_id BIGINT NOT NULL PRIMARY KEY,
log_channel_id BIGINT
);
+27
View File
@@ -0,0 +1,27 @@
"""Database schemas."""
import typing as t
from aiosqlite import Connection, Row
from pydantic import BaseModel
class GuildSettings(BaseModel):
"""Guild settings."""
guild_id: int
log_channel_id: int | None
@classmethod
def parse_obj(cls, obj: Row) -> "GuildSettings":
"""Deserialize a Row object from aiosqlite into a GuildSettings object."""
return cls(guild_id=obj[0], log_channel_id=obj[1])
@classmethod
async def from_db(cls, conn: Connection, guild_id: int) -> t.Optional["GuildSettings"]:
async with conn.cursor() as cursor:
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,))
row = await cursor.fetchone()
if row is None:
return None
return cls.parse_obj(row)
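
To try the `guild_settings` lookup without running the bot, the following sketch uses the synchronous standard-library `sqlite3` module and an in-memory database instead of `aiosqlite`. `GuildSettingsRow` and `load_settings` are hypothetical names for this example; the table definition is copied from `schema.sql`.

```python
import sqlite3
from dataclasses import dataclass
from typing import Optional

SCHEMA = """
CREATE TABLE IF NOT EXISTS guild_settings (
    guild_id BIGINT NOT NULL PRIMARY KEY,
    log_channel_id BIGINT
);
"""


@dataclass
class GuildSettingsRow:
    """Plain-Python counterpart of the settings row, for illustration only."""

    guild_id: int
    log_channel_id: Optional[int]


def load_settings(conn: sqlite3.Connection, guild_id: int) -> Optional[GuildSettingsRow]:
    """Fetch one guild's settings, or None when the guild has no row yet."""
    row = conn.execute(
        "SELECT guild_id, log_channel_id FROM guild_settings WHERE guild_id = ?",
        (guild_id,),  # parameterized query, never string formatting
    ).fetchone()
    if row is None:
        return None
    return GuildSettingsRow(guild_id=row[0], log_channel_id=row[1])


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
conn.execute("INSERT INTO guild_settings VALUES (?, ?)", (1, None))

print(load_settings(conn, 1))
print(load_settings(conn, 2))
```

Naming the columns in the `SELECT` keeps the positional access (`row[0]`, `row[1]`) stable even if columns are added to the table later.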
+4
View File
@@ -0,0 +1,4 @@
"""Extensions for the bot.
See: https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
"""
@@ -0,0 +1,104 @@
"""Guild settings."""
import hikari
import lightbulb
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import mention
from lightbulb.utils import permissions_in
from loguru import logger
plugin = lightbulb.Plugin("GuildSettings")
plugin.add_checks(lightbulb.guild_only)
plugin.add_checks(lightbulb.has_guild_permissions(hikari.Permissions.MANAGE_GUILD))
@plugin.command
@lightbulb.command("settings", "Bot settings for the server.")
@lightbulb.implements(lightbulb.SlashCommandGroup)
async def settings(_: lightbulb.SlashContext) -> None:
"""Bot settings for the server."""
# This will never execute because it is a group
pass
@settings.child
@lightbulb.command("get", "Get all the guild settings.")
@lightbulb.implements(lightbulb.SlashSubCommand)
async def get(ctx: lightbulb.SlashContext) -> None:
"""Get all the guild settings."""
conn: Connection = ctx.bot.d.db
assert ctx.guild_id is not None # `guild_only` check
async with conn.cursor() as cursor:
# Get all settings
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (ctx.guild_id,))
row = await cursor.fetchone()
if row is None:
logger.warning(f"No guild settings for {ctx.guild_id}")
await ctx.respond("No settings found for this guild.")
return
guild_settings = GuildSettings.parse_obj(row)
# Respond with all
# TODO: Embed
await ctx.respond(
f"""\
**Guild Settings**
`log_channel`: {
mention(guild_settings.log_channel_id, "channel")
if guild_settings.log_channel_id else 'not set'}
"""
)
@settings.child
@lightbulb.option("channel", "The channel to use.", hikari.TextableGuildChannel)
@lightbulb.command("log_channel", "Set the channel that the bot logs task and label completions in.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashSubCommand)
async def log_channel(ctx: lightbulb.SlashContext) -> None:
"""Set the channel that the bot logs task and label completions in."""
channel: hikari.TextableGuildChannel = ctx.options.channel
conn: Connection = ctx.bot.d.db
assert ctx.guild_id is not None # `guild_only` check
# Check if the bot can send messages in that channel
assert isinstance(channel, hikari.InteractionChannel) # Slash commands are interactions
me = ctx.bot.cache.get_me() or await ctx.bot.rest.fetch_my_user()
own_member = ctx.bot.cache.get_member(ctx.guild_id, me.id) or await ctx.bot.rest.fetch_member(ctx.guild_id, me.id)
# Get the channel from the cache if it is there, otherwise fetch it
if (ch := ctx.bot.cache.get_guild_channel(channel.id)) is None:
ch = {ch.id: ch for ch in await ctx.bot.rest.fetch_guild_channels(ctx.guild_id)}[channel.id]
if not isinstance(ch, hikari.GuildTextChannel):
await ctx.respond(f"{ch.mention} is not a text channel.")
return
# If the bot's permissions for this channel don't contain SEND_MESSAGES
# This will also filter out categories and voice channels
if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES:
await ctx.respond(f"I don't have permission to send messages in {ch.mention}.")
return
await ctx.respond(f"Setting `log_channel` to {channel.mention}.")
# update the database
async with conn.cursor() as cursor:
await cursor.execute(
"INSERT OR REPLACE INTO guild_settings (guild_id, log_channel_id) VALUES (?, ?)",
(ctx.guild_id, channel.id),
)
await conn.commit()
logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.")
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin from the bot."""
bot.remove_plugin(plugin)
+63
View File
@@ -0,0 +1,63 @@
"""Hot reload plugin."""
from glob import glob
import hikari
import lightbulb
from loguru import logger
plugin = lightbulb.Plugin(
"HotReloadPlugin",
)
plugin.add_checks(lightbulb.owner_only)
EXTENSIONS_FOLDER = "bot/extensions"
def _get_extensions() -> list[str]:
# Recursively get all the .py files in the extensions directory not starting with an `_`.
exts = glob(f"{EXTENSIONS_FOLDER}/**/[!_]*.py", recursive=True)
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]:
# Check that the option is a string.
if not isinstance(option.value, str):
raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`")
exts = _get_extensions()
return [ext for ext in exts if option.value in ext]
@plugin.command
@lightbulb.option(
"plugin",
"The plugin to reload. Leave empty to reload all plugins.",
autocomplete=_plugin_autocomplete,
required=False,
default=None,
)
@lightbulb.command("reload", "Reload a plugin", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def reload(ctx: lightbulb.SlashContext):
"""Reload a plugin or all plugins."""
# If the plugin option is None, reload all plugins.
if ctx.options.plugin is None:
ctx.bot.reload_extensions(*_get_extensions())
await ctx.respond("Reloaded all plugins.")
logger.info("Reloaded all plugins.")
# Otherwise, reload the specified plugin.
else:
ctx.bot.reload_extensions(ctx.options.plugin)
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
logger.info(f"Reloaded `{ctx.options.plugin}`.")
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin from the bot."""
bot.remove_plugin(plugin)
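Reviewer note: the path-to-module conversion in `_get_extensions` above can be sketched in isolation as follows. The helper name `to_module_path` is hypothetical, and the sketch uses `pathlib` rather than chained `str.replace` calls so that only the trailing `.py` suffix is stripped; it is not a drop-in replacement for the plugin code.

```python
from pathlib import PurePosixPath


def to_module_path(path: str) -> str:
    """Convert "bot/extensions/foo.py" (or a backslash path) into "bot.extensions.foo"."""
    # Normalize Windows separators so the same logic works on every platform.
    normalized = path.replace("\\", "/")
    # Drop the final suffix only, then join the path components with dots.
    return ".".join(PurePosixPath(normalized).with_suffix("").parts)


module = to_module_path("bot/extensions/hot_reload.py")
```

The resulting dotted string has the same shape as the values the `reload` command passes to `reload_extensions`.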
+179
View File
@@ -0,0 +1,179 @@
"""Text label plugin."""
import typing as t
from datetime import datetime
import hikari
import lightbulb
import miru
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from loguru import logger
plugin = lightbulb.Plugin(
"TextLabels",
)
plugin.add_checks(lightbulb.guild_only) # Context menus are only enabled in guilds
DISCORD_GRAY = 0x2F3136
def clamp(num: float) -> float:
"""Clamp a number between 0 and 1."""
return min(max(0.0, num), 1.0)
class LabelModal(miru.Modal):
"""Modal for submitting text labels."""
def __init__(self, label: str, content: str, *args: t.Any, **kwargs: t.Any):
super().__init__(*args, **kwargs)
self.label = label
self.original_content = content
# Add the text of the message to the modal
self.content = miru.TextInput(
label="Text", style=hikari.TextInputStyle.PARAGRAPH, value=content, required=True, row=1
)
self.add_item(self.content)
value = miru.TextInput(label="Value", placeholder="Enter a value between 0 and 1", required=True, row=2)
async def callback(self, context: miru.ModalContext) -> None:
val = float(self.value.value) if self.value.value else 0.0
val = clamp(val)
edited = self.content.value != self.original_content
await context.respond(
f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.",
flags=hikari.MessageFlag.EPHEMERAL,
)
logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.")
# Send a notification to the log channel
assert context.guild_id is not None # `guild_only` check
conn: Connection = context.bot.d.db # type: ignore
guild_settings = await GuildSettings.from_db(conn, context.guild_id)
if guild_settings is None or guild_settings.log_channel_id is None:
logger.warning(f"No guild settings or log channel for guild {context.guild_id}")
return
embed = (
hikari.Embed(
title="Message Label",
description=f"{context.author.mention} labeled a message as `{self.label}`.",
timestamp=datetime.now().astimezone(),
color=0x00FF00,
)
.set_author(name=context.author.username, icon=context.author.avatar_url)
.add_field("Total Labeled Messages", "0", inline=True)
.add_field("Server Ranking", "0/0", inline=True)
.add_field("Global Ranking", "0/0", inline=True)
)
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
assert isinstance(channel, hikari.TextableChannel)
await channel.send(embed=embed)
class LabelSelect(miru.View):
"""Select menu for selecting a label.
The current labels are:
- contains toxic language
- encourages illegal activity
- good quality
- bad quality
- is spam
"""
def __init__(self, content: str, *args: t.Any, **kwargs: t.Any):
super().__init__(*args, **kwargs)
self.content = content
@miru.select(
options=[
hikari.SelectMenuOption(
label="Toxic Language",
value="toxic_language",
description="The message contains toxic language.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Illegal Activity",
value="illegal_activity",
description="The message encourages illegal activity.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Good Quality",
value="good_quality",
description="The message is good quality.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Bad Quality",
value="bad_quality",
description="The message is bad quality.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Spam",
value="spam",
description="The message is spam.",
is_default=False,
emoji=None,
),
],
min_values=1,
max_values=1,
)
async def label_select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
"""Handle the select menu."""
label = select.values[0]
modal = LabelModal(label, self.content, title=f"Text Label: {label}", timeout=60)
await modal.send(ctx.interaction)
await modal.wait()
self.stop()
@plugin.command
@lightbulb.command("Label Message", "Label a message")
@lightbulb.implements(lightbulb.MessageCommand)
async def label_message_text(ctx: lightbulb.MessageContext):
"""Label a message."""
# We have to chain interactions because Discord only allows one component (select or modal) per interaction,
# so the select menu will open the modal
msg: hikari.Message = ctx.options.target
# Exit if the message is empty
if not msg.content:
await ctx.respond("Cannot label an empty message.", flags=hikari.MessageFlag.EPHEMERAL)
return
# Send the select menu
# The modal will be opened from the select menu interaction
embed = hikari.Embed(title="Label Message", description="Select a label for the message.", color=DISCORD_GRAY)
label_select_view = LabelSelect(
msg.content,
timeout=60,
)
resp = await ctx.respond(embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
await label_select_view.start(await resp.message())
await label_select_view.wait()
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin from the bot."""
bot.remove_plugin(plugin)
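Reviewer note: `LabelModal.callback` above converts the "Value" field with `float(...)` before clamping, which raises `ValueError` for non-numeric text. The sketch below shows one possible way to parse and clamp in a single step. The helper `parse_label_value` is hypothetical and not part of the plugin, and returning `None` for invalid input is an assumption about the desired behavior.

```python
import math
from typing import Optional


def clamp(num: float) -> float:
    """Clamp a number between 0 and 1 (same logic as the plugin's helper)."""
    return min(max(0.0, num), 1.0)


def parse_label_value(raw: Optional[str]) -> Optional[float]:
    """Parse user text into a value in [0, 1], or return None if it is not a number."""
    if not raw:
        return None
    try:
        value = float(raw)
    except ValueError:
        return None
    # NaN parses successfully but has no meaningful position in [0, 1].
    if math.isnan(value):
        return None
    return clamp(value)


parsed = parse_label_value("0.5")
```

A caller could respond with an error message when the result is `None` instead of silently submitting a default.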
@@ -0,0 +1,300 @@
"""Task plugin for testing different data collection methods."""
# TODO: Delete this once user input method has been decided for final bot.
import asyncio
import typing as t
from datetime import datetime, timedelta
import hikari
import lightbulb
import lightbulb.decorators
import miru
from bot.utils import format_time
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("TaskPlugin")
MAX_TASK_TIME = 60 * 60
MAX_TASK_ACCEPT_TIME = 60
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def task_thread(ctx: lightbulb.SlashContext):
"""Request a task from the backend."""
typ: str = ctx.options.type
# Create a thread for the task
thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}")
await ctx.respond(f"Please complete the task in the thread: {thread.mention}")
# Send the task in the thread
await thread.send(
f"""\
Please complete the task.
Sample Task
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
"""
)
# Wait for the user to respond
try:
event = await ctx.bot.wait_for(
hikari.GuildMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id,
)
await ctx.respond(f"Received message: {event.message.content}")
except asyncio.TimeoutError:
await ctx.respond("You took too long to respond.")
finally:
await thread.delete()
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def task_dm(ctx: lightbulb.Context):
"""Request a task from the backend."""
await ctx.respond("Please complete the task in your DMs")
# Send the task in the dm
await ctx.author.send(
f"""\
Please complete the task.
Sample Task
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
"""
)
# Wait for the user to respond
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id,
)
await ctx.respond(f"Received message: {event.message.content}")
except asyncio.TimeoutError:
await ctx.respond("You took too long to respond.")
class TaskModal(miru.Modal):
"""Modal for submitting a task."""
response = miru.TextInput(
label="Response",
placeholder="Enter your response!",
required=True,
style=hikari.TextInputStyle.PARAGRAPH,
row=2,
)
async def callback(self, context: miru.ModalContext) -> None:
await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL)
class ModalView(miru.View):
"""View for opening a modal."""
def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.modal_title = modal_title
self.task = task
@miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY)
async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
modal = TaskModal(title=self.modal_title)
modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1))
await ctx.respond_with_modal(modal)
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def task_modal(ctx: lightbulb.SlashContext):
"""Request a task from the backend."""
# typ: str = ctx.options.type
view = ModalView(
modal_title="Assistant Response",
task="Please explain the moon landing to a six year old.",
timeout=MAX_TASK_TIME,
)
resp = await ctx.respond(
"Task - Respond to the prompt as if you were the Assistant:",
flags=hikari.MessageFlag.EPHEMERAL,
components=view,
)
await view.start(await resp.message())
class RatingView(miru.View):
"""View for rating a task."""
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.presses: list[str] = []
def _close_if_all_pressed(self) -> None:
if len(self.presses) == 5:
self.stop()
@miru.button(label="1", style=hikari.ButtonStyle.PRIMARY)
async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("1")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="2", style=hikari.ButtonStyle.PRIMARY)
async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("2")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="3", style=hikari.ButtonStyle.PRIMARY)
async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("3")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="4", style=hikari.ButtonStyle.PRIMARY)
async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("4")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="5", style=hikari.ButtonStyle.PRIMARY)
async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("5")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="Reset", style=hikari.ButtonStyle.DANGER)
async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.presses = []
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
class SelectRating(miru.View):
"""View for rating a task with a select menu."""
@miru.select(
options=[
hikari.SelectMenuOption(
label="1",
value="1",
description=None,
emoji=None,
is_default=False,
),
hikari.SelectMenuOption(
label="2",
value="2",
description=None,
emoji=None,
is_default=False,
),
hikari.SelectMenuOption(
label="3",
value="3",
description=None,
emoji=None,
is_default=False,
),
],
placeholder="Select the good responses",
min_values=0,
max_values=3,
row=3,
)
async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL)
@plugin.command
@lightbulb.command("rating_task", "Rate stuff.")
@lightbulb.implements(lightbulb.SlashCommand)
async def rating_task(ctx: lightbulb.SlashContext):
"""Rate stuff."""
# Message Based rating
await ctx.respond(
"List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL
)
try:
event = await ctx.bot.wait_for(
hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
)
except asyncio.TimeoutError:
await ctx.respond("Timed out waiting for response")
return
if event.content is None:
await ctx.respond("No content in message")
return
ratings = event.content.replace(" ", "").split(",")
# Check if the ratings are valid
if len(ratings) != 5:
await ctx.respond("Invalid number of ratings")
return
if not all(rating in ("1", "2", "3", "4", "5") for rating in ratings):
await ctx.respond("Invalid rating")
return
await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL)
# Button Based rating
view = RatingView(timeout=MAX_TASK_TIME)
resp = await ctx.respond("Click the buttons in order of best to worst response", components=view)
await view.start(await resp.message())
await view.wait()
await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL)
await resp.delete()
# Select Based rating
select_view = SelectRating(timeout=MAX_TASK_TIME)
resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL)
await select_view.start(await resp_2.message())
await select_view.wait()
await resp_2.delete()
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin from the bot."""
bot.remove_plugin(plugin)
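Reviewer note: the message-based rating in `rating_task` above splits a comma-separated string and checks it by hand. A stricter check, sketched below, requires every rank from 1 to N to appear exactly once. The function name `parse_ranking` and the choice to return zero-based indices are assumptions made for this illustration.

```python
from typing import List, Optional


def parse_ranking(content: str, num_items: int) -> Optional[List[int]]:
    """Parse "2, 1, 3" into zero-based indices, or return None if the ranking is invalid."""
    parts = content.replace(" ", "").split(",")
    expected = {str(i) for i in range(1, num_items + 1)}
    # Every rank must appear exactly once: correct length and the full set of values.
    if len(parts) != num_items or set(parts) != expected:
        return None
    return [int(part) - 1 for part in parts]


ranking = parse_ranking("2, 1, 3", 3)
```

Checking both the length and the set rejects duplicates such as `1,1,3` as well as out-of-range or non-numeric entries.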
+461
View File
@@ -0,0 +1,461 @@
"""Work plugin for collecting user data."""
import asyncio
import typing as t
from uuid import UUID
import hikari
import lightbulb
import lightbulb.decorators
import miru
from aiosqlite import Connection
from bot.messages import (
assistant_reply_message,
confirm_ranking_response_message,
confirm_text_response_message,
initial_prompt_message,
invalid_user_input_embed,
plain_embed,
prompter_reply_message,
rank_assistant_reply_message,
rank_initial_prompts_message,
rank_prompter_reply_message,
task_complete_embed,
)
from bot.settings import Settings
from loguru import logger
from oasst_shared.api_client import OasstApiClient, TaskType
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("WorkPlugin")
MAX_TASK_TIME = 60 * 60 # 1 hour
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
settings = Settings()
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
required=False,
default=str(TaskRequestType.random),
type=str,
)
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def work(ctx: lightbulb.Context):
"""Create and handle a task."""
# Only send this message if started from a server
if ctx.guild_id is not None:
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
currently_working: dict[
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
] = ctx.bot.d.currently_working
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
if ctx.author.id in currently_working:
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
embed=plain_embed("You are already working. Would you like to cancel your old task and start a new one?"),
flags=hikari.MessageFlag.EPHEMERAL,
components=yn_view,
)
await yn_view.start(msg)
await yn_view.wait()
match yn_view.choice:
case False | None:
return
case True:
old_msg, task_id = currently_working[ctx.author.id]
if old_msg is not None:
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
await old_msg.delete()
if task_id is not None:
await oasst_api.nack_task(task_id, reason="user cancelled")
await msg.delete()
currently_working[ctx.author.id] = (None, None)
# Create a TaskRequestType from the stringified enum value
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
logger.debug(f"Starting task_type: {task_type!r}")
try:
await _handle_task(ctx, task_type)
finally:
currently_working.pop(ctx.author.id, None)
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
"""Handle creating and collecting user input for a task.
Continually present tasks to the user until they select one, cancel, or time out.
If they select one, present the task steps until a `task_done` task is received.
Finally, ask the user if they want to perform another task (of the same type).
"""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
# Continue to complete tasks until the user doesn't want to do another
done = False
while not done:
# Loop until the user accepts a task
task, msg_id = await _select_task(ctx, task_type)
if task is None:
# User cancelled
return
# Task action loop
completed = False
while not completed:
await ctx.author.send(embed=plain_embed("Please type your response here"))
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id
and not (e.message.content or "").startswith(settings.prefix),
)
except asyncio.TimeoutError:
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
await oasst_api.nack_task(task.id, reason="timed out")
logger.info(f"Task {task.id} timed out")
return
# Invalid response
valid, err_msg = _validate_user_input(event.content, task)
if not valid or event.content is None:
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
continue
logger.debug(f"Successful user input received: {event.content}")
# Confirm user input
if isinstance(task, protocol_schema.RankConversationRepliesTask):
content = confirm_ranking_response_message(event.content, task.replies)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
content = confirm_ranking_response_message(event.content, task.prompts)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
content = confirm_text_response_message(event.content)
else:
logger.critical(f"Unknown task type: {task.type}")
raise ValueError(f"Unknown task type: {task.type}")
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
msg = await ctx.author.send(content, components=confirm_resp_view)
await confirm_resp_view.start(msg)
await confirm_resp_view.wait()
match confirm_resp_view.choice:
case False | None:
continue
case True:
await msg.delete() # buttons are already gone
# Send the response to the backend
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
reply = protocol_schema.MessageRanking(
message_id=str(msg_id),
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
else:
logger.critical(f"Unexpected task type received: {task.type}")
raise ValueError(f"Unexpected task type received: {task.type}")
logger.debug(f"Sending reply to backend: {reply!r}")
# Get next task
new_task = await oasst_api.post_interaction(reply)
logger.info(f"New task {new_task}")
if new_task.type == TaskType.done:
await ctx.author.send(embed=plain_embed("Task completed"))
completed = True
continue
else:
logger.critical(f"Unexpected task type received: {new_task.type}")
# Send a message in all the log channels that the task is complete
conn: Connection = ctx.bot.d.db
async with conn.cursor() as cursor:
await cursor.execute("SELECT log_channel_id FROM guild_settings WHERE log_channel_id IS NOT NULL")
log_channel_ids = await cursor.fetchall()
channels = [
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
for id in log_channel_ids
]
done_embed = task_complete_embed(task, ctx.author.mention)
# This will definitely get the bot rate limited, but that's a future problem
await asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
# ask the user if they want to do another task
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
await another_task_view.start(msg)
await another_task_view.wait()
match another_task_view.choice:
case False | None:
done = True
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
case True:
pass
async def _select_task(
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
"""Present tasks to the user until they accept one, cancel, or time out."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
logger.debug(f"Starting task selection for {task_type}")
# Loop until the user accepts a task, cancels, or times out
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
while True:
logger.debug(f"Requesting task of type {task_type}")
task = await oasst_api.fetch_task(task_type, user)
resp, msg = await _send_task(ctx, task, msg)
msg_id = str(msg.id)
logger.debug(f"User choice: {resp}")
match resp:
case "accept":
logger.info(f"Task {task.id} accepted, sending ACK")
await oasst_api.ack_task(task.id, msg_id)
return task, msg_id
case "next":
logger.info(f"Task {task.id} rejected, sending NACK")
await oasst_api.nack_task(task.id, "rejected")
continue
case "cancel":
logger.info(f"Task {task.id} canceled, sending NACK")
await oasst_api.nack_task(task.id, "canceled")
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
return None, msg_id
case None:
logger.info(f"Task {task.id} timed out, sending NACK")
await oasst_api.nack_task(task.id, "timed out")
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
return None, msg_id
async def _send_task(
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
"""Send a task to the user.
Returns the user's choice and the task message.
"""
# The clean way to do this would be to attach a `to_embed` method to the task classes
# but the tasks aren't discord specific so that doesn't really make sense.
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.debug("sending initial prompt task")
content = initial_prompt_message(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.debug("sending rank initial prompt task")
content = rank_initial_prompts_message(task)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
content = rank_prompter_reply_message(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
content = rank_assistant_reply_message(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
content = prompter_reply_message(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.debug("sending assistant reply task")
content = assistant_reply_message(task)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"unknown task type {task.type}")
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
if not msg:
msg = await ctx.author.send(
content,
embed=embed,
components=view,
)
else:
await msg.edit(
content,
embed=embed,
components=view,
)
assert msg is not None
# Record the task message and task ID the user is currently working on
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
await view.start(msg)
await view.wait()
return view.choice, msg
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
"""Returns whether the user's input is valid for the task type and an error message."""
if content is None:
return False, "No input provided"
# User message input
if (
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.prompter_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0, "Message must be at least one character long."
# Ranking tasks
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
"Message must contain numbers for all replies.",
)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
"Message must contain numbers for all prompts.",
)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"Unknown task type {task.type}")
raise ValueError(f"Unknown task type {task.type}")
class TaskAcceptView(miru.View):
"""View with three buttons: accept, next, and cancel.
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
"""
choice: t.Literal["accept", "next", "cancel"] | None = None
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Accept button pressed")
self.choice = "accept"
await ctx.message.edit(component=None)
self.stop()
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Next button pressed")
self.choice = "next"
await ctx.message.edit(component=None)
self.stop()
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Cancel button pressed")
self.choice = "cancel"
await ctx.message.edit(component=None)
self.stop()
async def on_timeout(self) -> None:
if self.message is not None:
await self.message.edit(component=None)
class YesNoView(miru.View):
"""View with two buttons: yes and no.
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
"""
choice: bool | None = None
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = True
await ctx.message.edit(component=None)
self.stop()
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = False
await ctx.message.edit(component=None)
self.stop()
async def on_timeout(self) -> None:
if self.message is not None:
await self.message.edit(component=None)
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin from the bot."""
bot.remove_plugin(plugin)
+207
@@ -0,0 +1,207 @@
"""All user-facing messages and embeds."""
from datetime import datetime
import hikari
from oasst_shared.schemas import protocol as protocol_schema
NUMBER_EMOJIS = [":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", ":eight:", ":nine:", ":ten:"]
NL = "\n"
###
# Reusable 'components'
###
def _h1(text: str) -> str:
return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:"
def _h2(text: str) -> str:
return f"__**{text}**__"
def _h3(text: str) -> str:
return f"__{text}__"
def _writing_prompt(text: str) -> str:
return f":pencil: _{text}_"
def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_"
def _response_prompt(text: str) -> str:
return f":speech_balloon: _{text}_"
def _summarize_prompt(text: str) -> str:
return f":notepad_spiral: _{text}_"
def _user(text: str | None) -> str:
return f"""\
:person_red_hair: {_h3("User")}:{f"{NL}> **{text}**" if text is not None else ""}
"""
def _assistant(text: str | None) -> str:
return f"""\
:robot: {_h3("Assistant")}:{f"{NL}> {text}" if text is not None else ""}
"""
def _make_ordered_list(items: list[str]) -> list[str]:
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
def _ordered_list(items: list[str]) -> str:
return "\n\n".join(_make_ordered_list(items))
def _conversation(conv: protocol_schema.Conversation) -> str:
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
def _hint(hint: str | None) -> str:
return f"{NL}Hint: {hint}" if hint else ""
###
# Messages
###
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
return f"""\
{_h1("INITIAL PROMPT")}
{_writing_prompt("Please provide an initial prompt to the assistant.")}
{_hint(task.hint)}
"""
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
return f"""\
{_h1("RANK INITIAL PROMPTS")}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
{_ordered_list(task.prompts)}
"""
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
return f"""\
{_h1("RANK PROMPTER REPLIES")}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
{_conversation(task.conversation)}
{_user(None)}
{_ordered_list(task.replies)}
"""
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
return f"""\
{_h1("RANK ASSISTANT REPLIES")}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
{_conversation(task.conversation)}
{_assistant(None)}
{_ordered_list(task.replies)}
"""
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
return f"""\
{_h1("PROMPTER REPLY")}
{_response_prompt("Please provide a reply to the assistant.")}
{_conversation(task.conversation)}
{_hint(task.hint)}
"""
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
"""Creates the message that gets sent to users when they request an `assistant_reply` task."""
return f"""\
{_h1("ASSISTANT REPLY")}
{_response_prompt("Please provide a reply to the user.")}
{_conversation(task.conversation)}
"""
def confirm_text_response_message(content: str) -> str:
return f"""\
{_h2("CONFIRM RESPONSE")}
> {content}
"""
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
original_list = _make_ordered_list(items)
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
return f"""\
{_h2("CONFIRM RESPONSE")}
{user_ranked_list}
"""
###
# Embeds
###
def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed:
return (
hikari.Embed(
title="Task Completion",
description=f"`{task.type}` completed by {mention}",
color=hikari.Color(0x00FF00),
timestamp=datetime.now().astimezone(),
)
.add_field("Total Tasks", "0", inline=True)
.add_field("Server Ranking", "0/0", inline=True)
.add_field("Global Ranking", "0/0", inline=True)
.set_footer(f"Task ID: {task.id}")
)
def invalid_user_input_embed(error_message: str) -> hikari.Embed:
return hikari.Embed(
title="Invalid User Input",
description=error_message,
color=hikari.Color(0xFF0000),
timestamp=datetime.now().astimezone(),
)
def plain_embed(text: str) -> hikari.Embed:
return hikari.Embed(color=0x36393F, description=text)
+17
@@ -0,0 +1,17 @@
"""Configuration for the bot."""
from pydantic import BaseSettings, Field
class Settings(BaseSettings):
"""Settings for the bot."""
bot_token: str = Field(env="BOT_TOKEN", default="")
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
prefix: str = Field(env="PREFIX", default="/")
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
class Config(BaseSettings.Config):
env_file = ".env"
case_sensitive = False
+40
@@ -0,0 +1,40 @@
"""Utility functions."""
import typing as t
from datetime import datetime
import hikari
def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> str:
"""Format a datetime object into the discord time format.
```
| t | HH:MM | 16:20
| T | HH:MM:SS | 16:20:11
| D | D Mo Yr | 20 April 2022
| f | D Mo Yr HH:MM | 20 April 2022 16:20
| F | W, D Mo Yr HH:MM | Wednesday, 20 April 2022 16:20
| R | relative | in an hour
```
"""
match fmt:
case "t" | "T" | "D" | "f" | "F" | "R":
return f"<t:{dt.timestamp():.0f}:{fmt}>"
case _:
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
def mention(
id: hikari.Snowflakeish,
type: t.Literal["channel", "role", "user"],
) -> str:
"""Mention an object."""
match type:
case "channel":
return f"<#{id}>"
case "user":
return f"<@{id}>"
case "role":
return f"<@&{id}>"
-61
@@ -1,61 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from abc import ABC
from dataclasses import dataclass
from typing import Any
import discord
from api_client import ApiClient
from channel_handlers import ChannelHandlerBase
from loguru import logger
from message_templates import MessageTemplates
@dataclass
class ReplyHandlerInfo:
msg_id: int
handler_task: asyncio.Task
handler: ChannelHandlerBase
class BotBase(ABC):
bot_channel_name: str
debug: bool
backend: ApiClient
client: discord.Client
loop: asyncio.BaseEventLoop
owner_id: int
bot_channel: discord.TextChannel
templates: MessageTemplates
reply_handlers: dict[int, ReplyHandlerInfo]
def __init__(self):
self.reply_handlers = {} # handlers by msg_id
def ensure_bot_channel(self) -> None:
if self.bot_channel is None:
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
async def post(
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
) -> discord.Message:
if channel is None:
self.ensure_bot_channel()
channel = self.bot_channel
return await channel.send(content=content, view=view)
async def post_template(
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
) -> discord.Message:
logger.debug(f"rendering {name}")
text = self.templates.render(name, **kwargs)
return await self.post(text, view=view, channel=channel)
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
if msg_id in self.reply_handlers:
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
task.add_done_callback(lambda t: handler.on_completed())
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
-15
@@ -1,15 +0,0 @@
# -*- coding: utf-8 -*-
from pydantic import AnyHttpUrl, BaseSettings
class BotSettings(BaseSettings):
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
API_KEY: str = "any_key"
BOT_TOKEN: str
BOT_CHANNEL_NAME: str = "bot"
OWNER_ID: int = None
TEMPLATE_DIR: str = "./templates"
DEBUG: bool = True
settings = BotSettings(_env_file=".env")
-88
@@ -1,88 +0,0 @@
# -*- coding: utf-8 -*-
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
import discord
from loguru import logger
class ChannelExpiredException(Exception):
pass
class ChannelHandlerBase(ABC):
queue: asyncio.Queue
completed: bool = False
expiry_date: datetime
expired: bool = False
def __init__(self, *, expiry_date: datetime = None):
self.expiry_date = expiry_date
self.queue = asyncio.Queue()
async def read(self) -> discord.Message:
"""Call this method to read the next message from the user in the handler method."""
if self.expired:
raise ChannelExpiredException()
msg = await self.queue.get()
if msg is None:
if self.expired:
raise ChannelExpiredException()
else:
raise RuntimeError("Unexpected None message read")
return msg
def on_reply(self, message: discord.Message) -> None:
self.queue.put_nowait(message)
def on_expire(self) -> None:
logger.info("ChannelHandler: on_expire")
self.expired = True
self.queue.put_nowait(None)
def on_completed(self) -> None:
logger.info("ChannelHandler: on_completed")
self.completed = True
def tick(self, now: datetime):
if now > self.expiry_date and not self.expired:
self.on_expire()
@abstractmethod
async def handler_loop(self):
...
async def finalize(self):
pass
class AutoDestructThreadHandler(ChannelHandlerBase):
first_message: discord.Message = None
thread: discord.Thread = None
def __init__(self, *, expiry_date: datetime = None):
super().__init__(expiry_date=expiry_date)
async def read(self) -> discord.Message:
try:
return await super().read()
except ChannelExpiredException:
await self.cleanup()
raise
async def cleanup(self):
logger.debug("AutoDestructThreadHandler.cleanup")
if self.thread:
logger.debug(f"deleting thread: {self.thread.name}")
await self.thread.delete()
self.thread = None
if self.first_message:
logger.debug(f"deleting first_message: {self.first_message.content}")
await self.first_message.delete()
self.first_message = None
async def finalize(self):
await self.cleanup()
return await super().finalize()
+8 -4
@@ -1,16 +1,20 @@
# -*- coding: utf-8 -*-
"""Message templates for the discord bot."""
import typing
import jinja2
from loguru import logger
class MessageTemplates:
def __init__(self, template_dir="./templates"):
self.env = jinja2.Environment(
"""Create message templates for the discord bot."""
def __init__(self, template_dir: str = "./templates"):
self.env = jinja2.Environment( # noqa: S701
loader=jinja2.FileSystemLoader(template_dir),
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
)
def render(self, template_name, **kwargs):
def render(self, template_name: str, **kwargs: typing.Any):
template = self.env.get_template(template_name)
txt = template.render(kwargs)
logger.debug(txt)
+9 -7
@@ -1,7 +1,9 @@
discord.py==2.1.0
Jinja2==3.1.2
pydantic==1.9.1
python-dotenv==0.21.0
pytz==2022.7
requests==2.28.1
schedule==1.1.0
aiosqlite # database
hikari # discord framework
hikari-lightbulb # command handler
hikari-miru # modals and buttons
hikari[speedups]
loguru
pydantic[dotenv]
uvloop; os_name != 'nt' # Faster drop-in replacement for asyncio event loop
-267
@@ -1,267 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
import discord
from api_client import ApiClient
from bot_base import BotBase
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
name = discord.ui.TextInput(label="Name")
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
class ChannelTaskBase(AutoDestructThreadHandler):
thread_name: str = "Replies"
expires_after: timedelta = timedelta(minutes=5)
backend: ApiClient
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
try:
self.bot = bot
self.task = task
self.backend = bot.backend
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
except Exception:
logger.exception("start task failed")
await self.cleanup()  # try to clean up the message or thread
raise
bot.register_reply_handler(msg_id=msg.id, handler=self)
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
pass
@abstractmethod
async def send_first_message(self) -> discord.Message:
...
def to_api_user(self, user: discord.User) -> protocol_schema.User:
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
async def post_teaser_msg(self, template_name: str):
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
return await self.bot.post_template(
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
)
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
api_response = await self.backend.post_interaction(interaction)
if api_response.type != "task_done":
# multi-step tasks are not supported yet
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
raise RuntimeError("Unexpected response from backend received")
return api_response
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.TextReplyToPost(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
text=user_msg.content,
)
)
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
self.post_text_reply_to_post(user_msg)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_text_reply_to_post()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.PostRanking(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
ranking=ranking,
)
)
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
ranking_str = user_msg.content
ranking = [int(x) - 1 for x in ranking_str.split(",")]
self.post_ranking(user_msg, ranking=ranking)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_ranking()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
class SummarizeStoryHandler(ChannelTaskBase):
task: protocol_schema.SummarizeStoryTask
thread_name: str = "Summaries"
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_summarize_story.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class InitialPromptHandler(ChannelTaskBase):
task: protocol_schema.InitialPromptTask
thread_name: str = "Prompts"
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_initial_prompt.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class UserReplyHandler(ChannelTaskBase):
task: protocol_schema.UserReplyTask
thread_name: str = "User replies"
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_user_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class AssistantReplyHandler(ChannelTaskBase):
task: protocol_schema.AssistantReplyTask
thread_name: str = "Assistant replies"
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_assistant_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class RankInitialPromptsHandler(ChannelTaskBase):
task: protocol_schema.RankInitialPromptsTask
thread_name: str = "User Responses"
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RankConversationsHandler(ChannelTaskBase):
task: protocol_schema.RankConversationRepliesTask
thread_name: str = "Rankings"
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RatingButton(discord.ui.Button):
def __init__(self, label, value, response_handler):
super().__init__(label=label, style=discord.ButtonStyle.green)
self.value = value
self.response_handler = response_handler
async def callback(self, interaction):
await self.response_handler(self.value, interaction)
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
view = discord.ui.View()
for i in range(lo, hi + 1):
view.add_item(RatingButton(str(i), i, response_handler))
return view
class RateSummaryHandler(ChannelTaskBase):
task: protocol_schema.RateSummaryTask
thread_name: str = "Ratings"
async def _rating_response_handler(self, score, interaction: discord.Interaction):
logger.info(f"rating_response_handler: {score}")
if self.thread:
try:
self.backend.post_interaction(
protocol_schema.PostRating(
post_id=str(self.first_message.id),
user_post_id=str(interaction.id),
user=self.to_api_user(interaction.user),
rating=score,
)
)
await interaction.response.send_message(
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
)
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in _rating_response_handler()")
await interaction.response.send_message(f"❌ Error communicating with backend: {e}")
async def send_first_message(self) -> discord.Message:
return await self.post_teaser_msg("teaser_rate_summary.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
logger.info(f"on_rate_summary_reply: {msg.content}")
await msg.add_reaction("")
await msg.reply("❌ Text input not supported.")
-52
@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
import enum
import subprocess
from datetime import datetime
import pytz
def get_git_head_hash():
# get current git hash
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
if x.returncode == 0:
return x.stdout.replace("\n", "")
return None
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
class DiscordTimestampStyle(str, enum.Enum):
"""
Timestamp Styles
t 16:20 Short Time
T 16:20:30 Long Time
d 20/04/2021 Short Date
D 20 April 2021 Long Date
f * 20 April 2021 16:20 Short Date/Time
F Tuesday, 20 April 2021 16:20 Long Date/Time
R 2 months ago Relative Time
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
"""
default = ""
short_time = "t"
long_time = "T"
short_date = "d"
long_date = "D"
short_date_time = "f"
long_date_time = "F"
relative_time = "R"
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
parts = ["<t:", str(int(d.timestamp()))]
if style:
parts.append(":")
parts.append(style)
parts.append(">")
return "".join(parts)
+25 -2
@@ -4,12 +4,12 @@ services:
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work on the backend.
backend-dev:
image: sverrirab/sleep
depends_on: [db, adminer]
depends_on: [db, adminer, redis, redis-insights]
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
frontend-dev:
image: sverrirab/sleep
depends_on: [db, webdb, adminer, maildev, backend]
depends_on: [db, webdb, adminer, maildev, backend, redis]
# This DB is for the FastAPI Backend.
db:
@@ -27,6 +27,26 @@ services:
timeout: 2s
retries: 10
# Redis - caching + rate limiting on BE
redis:
image: redis
restart: always
ports:
- 6379:6379
healthcheck:
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
interval: 2s
timeout: 2s
retries: 10
command: redis-server /usr/local/etc/redis/redis.conf
volumes:
- ./redis.conf:/usr/local/etc/redis/redis.conf
# insights host - redis:6379
redis-insights:
image: redislabs/redisinsight:latest
ports:
- 8001:8001
# This DB is for Web Authentication and data caching.
webdb:
image: postgres
@@ -67,9 +87,11 @@ services:
backend:
build:
dockerfile: docker/Dockerfile.backend
context: .
image: oasst-backend
environment:
- POSTGRES_HOST=db
- REDIS_HOST=redis
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
@@ -83,6 +105,7 @@ services:
web:
build:
dockerfile: docker/Dockerfile.website
context: .
image: oasst-web
environment:
- DATABASE_URL=postgres://postgres:postgres@webdb/oasst_web
+1
@@ -5,6 +5,7 @@ COPY ./backend/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
ENV PORT 8080
EXPOSE 8080
COPY ./oasst-shared /oasst-shared
RUN pip install -e /oasst-shared
+3 -3
@@ -1,7 +1,7 @@
FROM python:3.10-slim-bullseye
RUN mkdir /app
COPY ./discord-bot/requirements.txt /requirements.txt
RUN pip install -r requirements.txt
WORKDIR /app
COPY ./discord-bot /app
CMD ["python", "bot.py"]
COPY ./oasst-shared/oasst_shared /app/oasst_shared
RUN pip install -r requirements.txt
CMD ["python","-m","bot"]
+7 -2
@@ -1,9 +1,14 @@
# Documentation
This directory contains the documentation for the project and other related organization documents.
This directory contains the documentation for the project and other related
organization documents.
## Contributing to this documentation
Please make a pull request to the `main` branch with your changes.
Consider that this folder is used for documenting the various code sub-parts, the high-level ideas, the ML aspects, experiments, contributor guides, guides for data creation, and many more things. Please try to keep the documentation as concise as possible and keep an organized folder structure that makes sense for everyone.
Consider that this folder is used for documenting the various code sub-parts,
the high-level ideas, the ML aspects, experiments, contributor guides, guides
for data creation, and many more things. Please try to keep the documentation as
concise as possible and keep an organized folder structure that makes sense for
everyone.
+23
@@ -0,0 +1,23 @@
# Data Augmentation
(pull request welcome)
## What is data augmentation
Data augmentation is a technique we can use to get better data faster. It uses
machine learning models to analyze long texts (such as an essay) and compress
them into instructions.
## How to contribute
To contribute to data augmentation you can write a short Python script that uses
a model from HuggingFace to analyze the text.
[Here](https://docs.google.com/document/d/13a188pPvqnlvuVa3e_suVz4YO5s-JWeiOOrpp0odImg/edit)
are examples of what you can do.
And here are example implementations:
[Idea 3](https://colab.research.google.com/drive/1GllCN5PgSYxBxINZsv3A2r0SpdznHlbT?usp=sharing),
[Idea 4](https://colab.research.google.com/drive/1nZx5LRjO61fYprFyqtrwPDLOis6ctR4p#scrollTo=1EE8CriiaCXj)
To contribute, simply choose one of the ideas from the document above and
implement it.
+208
@@ -0,0 +1,208 @@
# OpenAssistant Data Schemas
## Introduction
This document describes the data schemas used by OpenAssistant. The schemas are
defined as Python classes, but can be implemented in any format, be that Python,
JSON, XML, SQL, Parquet files, etc.
The schemas lean heavily on the
[OpenAssistant Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
presentation.
_Note on conformity: be pragmatic and decide what makes sense 🙂. It's more
important that we move forward than to cram everything into a uniform thing._
## Data Schemas
### Main structure: conversation trees
Conversation trees are the fundamental data structure. Many of the datasets we
want to collect can be represented as conversation trees, such as QA datasets,
chat logs, Reddit dumps, etc. The main idea is that a conversation tree starts
with a prompt and branches out from there. Every node can also have metadata,
such as collected rankings, labels, or other information.
Datasets that just represent linear data, such as a list of questions and
answers, can be represented as a conversation tree with just a single branch.
```python
class ConversationTreeNode:
text: str # The text of the node
role: Literal['prompter', 'assistant'] # Whether the node is a user prompt/follow-up or an assistant response
children: list[ConversationTreeNode] # The children of the node (if you have a linear conversation, this will be of length 0 or 1)
metadata: dict[str, Any] # Node metadata (see below)
class ConversationTree:
root: ConversationTreeNode # The node containing the initial prompt
metadata: dict[str, Any] # Tree metadata, different from root node metadata.
```
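
As a rough illustration of the single-branch case described above, the sketch
below turns the two classes into dataclasses and chains a list of turns into a
linear tree. The `linear_tree` helper and the `source` metadata key are
assumptions made for this example and are not part of the schema.

```python
from dataclasses import dataclass, field
from typing import Any, Literal


@dataclass
class ConversationTreeNode:
    text: str
    role: Literal["prompter", "assistant"]
    children: list["ConversationTreeNode"] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class ConversationTree:
    root: ConversationTreeNode
    metadata: dict[str, Any] = field(default_factory=dict)


def linear_tree(turns: list[tuple[str, str]], **tree_metadata: Any) -> ConversationTree:
    """Hypothetical helper: build a single-branch tree from (role, text) pairs."""
    if not turns:
        raise ValueError("a conversation needs at least one turn")
    nodes = [ConversationTreeNode(text=text, role=role) for role, text in turns]
    # Each node becomes the only child of the node before it.
    for parent, child in zip(nodes, nodes[1:]):
        parent.children.append(child)
    return ConversationTree(root=nodes[0], metadata=dict(tree_metadata))


tree = linear_tree(
    [("prompter", "What is 2 + 2?"), ("assistant", "4"), ("prompter", "Thanks!")],
    source="example",  # illustrative tree-level metadata key
)
```

Every node in the result has zero or one child, which matches the note on
`children` in the class definition.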
### Metadata
Metadata encapsulates all the information that is not part of the conversation
itself. This includes data about how the node was created (i.e. where it is
from: crowd-sourced, templated, scraped, etc.), when it was created, its labels,
tags, collected rankings, and other information.
## Example: Reddit AMA dataset
- Represent each question-follow-up set as a conversation tree.
- Store things like usernames, timestamps, upvotes, etc. as metadata of the
nodes.
- Store things like the AMA title, the AMA author, the AMA subreddit, etc. as
metadata of the tree.
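
The mapping above could look roughly like the following sketch, which uses
plain nested dicts instead of the classes. The input field names (`username`,
`timestamp`, `upvotes`, `title`, `author`, `subreddit`) and the rule that the
AMA author takes the `assistant` role are assumptions for illustration, not a
fixed format.

```python
from typing import Any


def ama_to_tree(ama: dict[str, Any], comments: list[dict[str, Any]]) -> dict[str, Any]:
    """Hypothetical helper: chain a question and its follow-ups into one tree."""
    nodes = [
        {
            "text": c["text"],
            # Assumption: the AMA author answers, everyone else prompts.
            "role": "assistant" if c["username"] == ama["author"] else "prompter",
            "children": [],
            # Per-comment details go into node metadata.
            "metadata": {k: c[k] for k in ("username", "timestamp", "upvotes")},
        }
        for c in comments
    ]
    for parent, child in zip(nodes, nodes[1:]):
        parent["children"].append(child)
    # AMA-level details go into tree metadata.
    return {"root": nodes[0], "metadata": dict(ama)}


ama = {"title": "I build compilers, AMA", "author": "host", "subreddit": "IAmA"}
comments = [
    {"text": "Favourite optimisation?", "username": "alice", "timestamp": 1672790000, "upvotes": 12},
    {"text": "Inlining.", "username": "host", "timestamp": 1672790300, "upvotes": 30},
]
tree = ama_to_tree(ama, comments)
```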
## Example: QA dataset
- Represent each question-answer pair as a conversation tree.
- The question is the prompt, the answer is the assistant response.
- If the dataset contains multiple answers to each question, each answer can be
a child of the question node.
- If the dataset contains context text, it can be added as metadata to the
question node.
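
A minimal sketch of the QA mapping, again with nested dicts: each answer
becomes a child of the question node and the optional context is stored in the
question's metadata. The `qa_to_tree` name and the `context` and `source` keys
are assumptions for this example.

```python
from typing import Any, Optional


def qa_to_tree(question: str, answers: list[str], context: Optional[str] = None) -> dict[str, Any]:
    """Hypothetical helper: one question node with one child per answer."""
    root = {
        "text": question,
        "role": "prompter",
        "children": [
            {"text": answer, "role": "assistant", "children": [], "metadata": {}}
            for answer in answers
        ],
        # Context text, if any, is attached to the question node.
        "metadata": {"context": context} if context is not None else {},
    }
    return {"root": root, "metadata": {"source": "qa_dataset"}}  # illustrative key


tree = qa_to_tree(
    "What is the capital of France?",
    ["Paris", "Paris, on the Seine."],
    context="Geography quiz",
)
```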
## Example: Templated math problem dataset
- Represent each problem as a conversation tree with the problem text as the
prompt and the solution as the assistant response.
- Store the problem type (e.g. algebra, geometry, etc.) as metadata of the tree.
- Store the template used also as metadata of the tree, as well as the source of
the data used to fill the template.
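
For the templated case, a sketch along these lines records the template and the
origin of the fill-in values in the tree metadata. The template string, the
metadata key names, and the use of a seeded `random.Random` as the data source
are all assumptions chosen for illustration.

```python
import random
from typing import Any

TEMPLATE = "What is {a} + {b}?"  # illustrative template


def make_addition_problem(seed: int) -> dict[str, Any]:
    """Hypothetical generator: fill the template and keep provenance in metadata."""
    rng = random.Random(seed)
    values = {"a": rng.randint(1, 99), "b": rng.randint(1, 99)}
    solution = {
        "text": str(values["a"] + values["b"]),
        "role": "assistant",
        "children": [],
        "metadata": {},
    }
    root = {
        "text": TEMPLATE.format(**values),
        "role": "prompter",
        "children": [solution],
        "metadata": {},
    }
    return {
        "root": root,
        "metadata": {
            "problem_type": "arithmetic",
            "template": TEMPLATE,
            "template_values": values,
            "data_source": f"random.Random({seed})",
        },
    }


tree = make_addition_problem(0)
```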
## File Formats
The above data should be representable in most file formats, but some care has
to be taken with respect to the recursive nature of the data.
Most row-major formats (JSON, Avro, Protobuf, etc.), as well as many databases,
have no trouble with recursive (or arbitrary) schemas, but column-major formats,
such as Parquet, do. For datasets with linear conversations, like many of the
datasets we are collecting, this is not a problem. Instead of a tree of nodes,
simply represent the conversation as a list of nodes. For true tree-like
conversations, we should use a row-major format.
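
One way to get from a tree to the flat lists mentioned above is to enumerate
every root-to-leaf path, as in the sketch below. It assumes nodes are nested
dicts with `text`, `role` and `children` keys. A linear conversation yields
exactly one list, while a branching tree yields one list per leaf and so
duplicates shared prefixes.

```python
from typing import Any, Iterator


def iter_threads(node: dict[str, Any], prefix: tuple = ()) -> Iterator[list[dict[str, str]]]:
    """Yield each root-to-leaf path as a flat list of {"text", "role"} messages."""
    path = prefix + ({"text": node["text"], "role": node["role"]},)
    if not node["children"]:
        yield list(path)
        return
    for child in node["children"]:
        yield from iter_threads(child, path)


root = {
    "text": "Hi",
    "role": "prompter",
    "children": [
        {"text": "Hello!", "role": "assistant", "children": []},
        {"text": "Hey.", "role": "assistant", "children": []},
    ],
}
threads = list(iter_threads(root))
```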
## Other considerations
- For text data of moderate size, the choice of file format doesn't matter much.
It's more important to use consistent data structures and naming than to worry
about the exact file format.
- For crowd-sourced data, we are collecting it into a SQL database already.
- Parquet files are a good choice for large datasets, modulo the issues with
recursive schemas.
- If parquet can't be used, gzipped JSON-line files are a good choice. So are
Avro files and protobufs. Keep in mind that column-major files are better for
reading, filtering, and aggregating, but row-major files are better for
writing.
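For the gzipped JSON-lines option, a minimal sketch using only the Python standard library could look like this. The function names and the record layout are assumptions for illustration; one JSON document per line keeps nested (tree-shaped) records intact.

```python
import gzip
import json
import os
import tempfile


def write_jsonl_gz(path, records):
    """Write one JSON document per line into a gzip-compressed text file."""
    with gzip.open(path, "wt", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl_gz(path):
    """Yield records back from a gzip-compressed JSON-lines file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


# Hypothetical nested record, shaped like a small conversation tree.
trees = [
    {"root": {"text": "Hi", "role": "prompter", "children": []}, "metadata": {"lang": "en"}},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trees.jsonl.gz")
    write_jsonl_gz(path, trees)
    loaded = list(read_jsonl_gz(path))
```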
# Task-Specific Data Schemas
The main tasks are a) generation of response text and b) ranking of responses.
The following sections describe the data schemas for each of these tasks. Both
should be implementable in parquet files.
Note: These files are meant to be consumed by ML algorithms and should ideally
be produced from the above files.
## Common Data Structures
```python
class Message:
    text: str  # The text of the message
    role: Literal['prompter', 'assistant']  # Whether the message is a user prompt/follow-up or an assistant response


class Thread:
    messages: list[Message]  # The messages in the conversation
```
The corresponding parquet schemas are:
```parquet
message Message {
  required binary text (UTF8);
  required binary role (UTF8);
}

message Thread {
  required group messages (LIST) {
    repeated group list {
      required group element {
        required binary text (UTF8);
        required binary role (UTF8);
      }
    }
  }
}
```
## Generation
```python
class GenerationExample:
    thread: Thread  # The conversation thread before the message to be generated
    message: Message  # The message to be generated
```
The corresponding parquet schema is:
```parquet
message GenerationExample {
  required group thread (LIST) {
    repeated group list {
      required group element {
        required binary text (UTF8);
        required binary role (UTF8);
      }
    }
  }
  required group message {
    required binary text (UTF8);
    required binary role (UTF8);
  }
}
```
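A sketch of how generation examples might be derived from a linear thread is shown below. It assumes that every assistant message is treated as a target and that all earlier messages form the preceding thread; this selection rule and the function name `generation_examples` are assumptions of the example, not something fixed above.

```python
def generation_examples(messages):
    """Split a list of {"text", "role"} messages into (thread, message) examples.

    Assumption: each assistant message becomes one target, with all
    earlier messages as its preceding thread.
    """
    examples = []
    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            examples.append({"thread": messages[:i], "message": message})
    return examples


conversation = [
    {"text": "Hi!", "role": "prompter"},
    {"text": "Hello! How can I help you?", "role": "assistant"},
    {"text": "What is Parquet?", "role": "prompter"},
    {"text": "A column-major file format.", "role": "assistant"},
]
examples = generation_examples(conversation)
```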
## Ranking
```python
class RankingExample:
    thread: Thread  # The conversation thread before the message to be ranked
    messages: list[Message]  # The messages to be ranked, in order of decreasing preference
```
The corresponding parquet schema is:
```parquet
message RankingExample {
  required group thread (LIST) {
    repeated group list {
      required group element {
        required binary text (UTF8);
        required binary role (UTF8);
      }
    }
  }
  required group messages (LIST) {
    repeated group list {
      required group element {
        required binary text (UTF8);
        required binary role (UTF8);
      }
    }
  }
}
```
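Because the messages in a ranking example are stored in order of decreasing preference, pairwise comparisons can be recovered from a single example. The sketch below shows one possible expansion; the function name `ranking_to_pairs` and the `preferred`/`rejected` keys are hypothetical, and whether a downstream trainer wants pairs at all is an assumption.

```python
from itertools import combinations


def ranking_to_pairs(thread, ranked_messages):
    """Expand a ranked list (best first) into all (preferred, rejected) pairs."""
    # combinations() preserves input order, so the first element of each
    # pair always ranks higher than the second.
    return [
        {"thread": thread, "preferred": better, "rejected": worse}
        for better, worse in combinations(ranked_messages, 2)
    ]


thread = [{"text": "Explain Parquet briefly.", "role": "prompter"}]
ranked = [
    {"text": "Parquet is a column-major file format.", "role": "assistant"},
    {"text": "It is a file format.", "role": "assistant"},
    {"text": "No idea.", "role": "assistant"},
]
pairs = ranking_to_pairs(thread, ranked)
```

A ranking of `n` messages yields `n * (n - 1) / 2` pairs, so three ranked replies give three comparisons.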
+52 -25
View File
@@ -11,59 +11,86 @@
## 2. When you play the assistant:
- The assistant's primary goal is to provide helpful and accurate information to the user
- The assistant's primary goal is to provide helpful and accurate information to
the user
- Provide accurate and reliable information using credible sources and
references as appropriate
- Avoid providing vague or incomplete responses, or giving opinions or personal
advice unless specifically requested
- The assistant should always be respectful and polite, even if the user is not
- If the user asks for help with harmful actions, the assistant should explain why those actions are not appropriate and suggest alternative options
- The assistant should never insult the user or engage in any inappropriate or offensive behavior
- If the user asks for help with harmful actions, the assistant should explain
why those actions are not appropriate and suggest alternative options
- The assistant should never insult the user or engage in any inappropriate or
offensive behavior
## 3. When you play the user:
- Try to come up with a variety of different queries that reflect real-life situations and needs
- These queries should be relevant to your everyday life and work, including any specialized knowledge or skills you have
- Try to come up with a variety of different queries that reflect real-life
situations and needs
- These queries should be relevant to your everyday life and work, including any
specialized knowledge or skills you have
- Avoid asking inappropriate or offensive questions
## 4. While comparing multiple replies of the assistant:
- Longer and more explanatory answers are generally preferred over short, simplistic statements
- However, it is important to ensure that the information provided is accurate and helpful
- If multiple replies are being compared, choose the one that is most helpful and accurate, even if it is not the shortest or most concise.
- Longer and more explanatory answers are generally preferred over short,
simplistic statements
- However, it is important to ensure that the information provided is accurate
and helpful
- If multiple replies are being compared, choose the one that is most helpful
and accurate, even if it is not the shortest or most concise.
## 5. Additional guidelines for creating prompts:
- Avoid using language that could be considered offensive or discriminatory
- Do not include personal information in the prompts, such as names or addresses
- When asking for sensitive information, make sure to explain the purpose and secure handling of the information
- When asking for sensitive information, make sure to explain the purpose and
secure handling of the information
- Avoid creating prompts that encourage illegal or dangerous activities
- Use proper grammar and spelling to ensure the AI assistant can understand and respond accurately
- Consider the cultural context and appropriateness of the prompts for a global audience.
- Use proper grammar and spelling to ensure the AI assistant can understand and
respond accurately
- Consider the cultural context and appropriateness of the prompts for a global
audience.
## 6. Tips for playing the AI assistant:
- Think about how a real person would respond to the prompt, and try to mimic that tone and language
- Think about how a real person would respond to the prompt, and try to mimic
that tone and language
- Avoid using technical jargon or language that may be confusing to the user
- Use complete sentences and proper grammar to make the response clear and easy to understand
- When providing information, try to include relevant sources or references to back up your statements
- Use complete sentences and proper grammar to make the response clear and easy
to understand
- When providing information, try to include relevant sources or references to
back up your statements
## 8. Tips for playing the user:
- Think about the types of questions you might ask a real assistant or search engine
- Try to be specific and clear in your prompts to make it easier for the AI assistant to understand and provide a helpful response
- Think about the types of questions you might ask a real assistant or search
engine
- Try to be specific and clear in your prompts to make it easier for the AI
assistant to understand and provide a helpful response
- Avoid using vague or ambiguous language
- Consider the context in which the user might be asking the question, and try to provide relevant information and options for the user to consider.
- Consider the context in which the user might be asking the question, and try
to provide relevant information and options for the user to consider.
## Examples
Imagine your role is to play the assistant and a user has submitted the following query:
Imagine your role is to play the assistant and a user has submitted the
following query:
User: "Hi!"
<br>Agent: "Hello! How can I help you?"
User: "Hi!" <br>Agent: "Hello! How can I help you?"
User: "euirdteunvglfe23908230892309832098 AAAAAAAA"
<br>Agent: "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?"
User: "euirdteunvglfe23908230892309832098 AAAAAAAA" <br>Agent: "Sorry, I did not
understand your request and it is unclear to me what you want me to do. Could
you describe it in a different way?"
User: "You dumb motherf*&ker AI sh*t?!"
<br>Agent: "Sorry, but I have to inform you that your request might not be in conformance with our community guidelines. I am always here to help you. If you are angry it might help to relax or to go for a walk. Always stay positive."
User: "You dumb motherf*&ker AI sh*t?!" <br>Agent: "Sorry, but I have to inform
you that your request might not be in conformance with our community guidelines.
I am always here to help you. If you are angry it might help to relax or to go
for a walk. Always stay positive."
[Insert more examples]
\*We drafted this guide using ChatGPT, OpenAIs large-scale language-generation model. Upon generating draft language, the authors reviewed, edited, and revised the language to their own liking and take ultimate responsibility for the content of this publication.
\*We drafted this guide using ChatGPT, OpenAI's large-scale language-generation
model. Upon generating draft language, the authors reviewed, edited, and revised
the language to their own liking and take ultimate responsibility for the
content of this publication.
+34
View File
@@ -0,0 +1,34 @@
# Research
This page lists research papers that are relevant to the project.
## Automatically Generating Instruction Data for Training
This line of work is about significantly reducing the need for manually
annotated data for the purpose of training
[instruction-aligned](https://openai.com/blog/instruction-following/) language
models.
### SELF-INSTRUCT: Aligning Language Model with Self Generated Instructions [[ArXiv](https://arxiv.org/pdf/2212.10560.pdf)], [[Github](https://github.com/yizhongw/self-instruct)].
> We introduce SELF-INSTRUCT, a framework for improving the
> instruction-following capabilities of pretrained language models by
> bootstrapping off its own generations. Our pipeline generates instruction,
> input, and output samples from a language model, then prunes them before using
> them to finetune the original model. Applying our method to vanilla GPT3, we
> demonstrate a 33% absolute improvement over the original model on
> SuperNaturalInstructions, on par with the performance of InstructGPT-001,
> which is trained with private user data and human annotations.
### Tuning Language Models with (Almost) No Human Labor. [[ArXiv](https://arxiv.org/pdf/2212.09689.pdf)], [[Github](https://github.com/orhonovich/unnatural-instructions)].
> In this work, we introduce Unnatural Instructions: a large dataset of creative
> and diverse instructions, collected with virtually no human labor. We collect
> 64,000 examples by prompting a language model with three seed examples of
> instructions and eliciting a fourth. This set is then expanded by prompting
> the model to rephrase each instruction, creating a total of approximately
> 240,000 examples of instructions, inputs, and outputs. Experiments show that
> despite containing a fair amount of noise, training on Unnatural Instructions
> rivals the effectiveness of training on open-source manually-curated datasets,
> surpassing the performance of models such as T0++ and Tk-Instruct across
> various benchmarks.
+123
View File
@@ -0,0 +1,123 @@
# Cohere Grounded QA
[Cohere AI created a question-answering chatbot](https://github.com/cohere-ai/sandbox-grounded-qa)
that can
1. Understand questions in the context of a conversation
2. Search the internet for related information
3. Identify which information in the search results is relevant to the question
4. Synthesize the information into an answer to the question
## Cohere API
[Cohere's generate function](https://docs.cohere.ai/reference/generate):
Continues a text prompt using either the `medium` or `xlarge` model.
[Cohere's embed function](https://docs.cohere.ai/reference/embed): Embeds a
list of strings using either the `small` or `large` model. Alternatively, you
can specify the ID of a custom model and use that instead.
## Grounded QA System
Cohere's Grounded QA system makes 4 calls to the Cohere API:
1. Get contextualized question as a query to Google
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/model.py))
- Input: Chat History
- Output: Contextualized Question
- API Call: `cohere.generate`
- Model: `xlarge`
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/prompt_data/get_contextual_search_query.prompt):
Nine few-shot examples of (Chat History, Contextualized Question) pairs
followed by the current chat history and the prompt "question: "
2. Generate sample answer to compare with search results
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/model.py))
- Input: Contextualized Question
- Output: Sample Answer
- API Call: `cohere.generate`
- Model: `xlarge`
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/prompt_data/get_sample_answer.prompt):
Some task instructions followed by 12 few-shot examples of (Contextualized
Question, Sample Answer) pairs followed by the current contextualized
question and the prompt "answer: "
3. Get embeddings to rank search results by cosine similarity to sample answer
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/search.py))
- Input: Sample Answer, Search Results
- Output: Embeddings of sample answer and all search result documents
- API Call: `cohere.embed`
- Model: `multilingual-22-12`
4. Condition on the top 2 most similar search results and answer the question
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/answer.py))
- Input: Top 2 Search Results, Contextualized Question
- Output: Answer
- API Call: `cohere.generate`
- Model: `xlarge`
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/43f3e9710112dcc8c92652ac1326ed9330823ddf/qa/answer.py#L25):
Task instructions followed by the context and question.
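Step 3 above amounts to ranking documents by cosine similarity to the sample answer. The sketch below shows that ranking on plain Python lists; it does not call the Cohere API, and the function names, the toy vectors, and the default `k=2` are assumptions chosen to mirror the description.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat zero vectors as dissimilar to everything
    return dot / (norm_a * norm_b)


def top_k_results(answer_embedding, result_embeddings, k=2):
    """Return indices of the k search results most similar to the sample answer."""
    scored = [
        (cosine_similarity(answer_embedding, emb), i)
        for i, emb in enumerate(result_embeddings)
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [i for _, i in scored[:k]]


# Toy 2-dimensional embeddings standing in for real model output.
answer = [1.0, 0.0]
results = [[0.0, 1.0], [1.0, 0.1], [2.0, 0.0], [-1.0, 0.0]]
best = top_k_results(answer, results)
```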
## Models
Cohere's model documentation is pretty sparse.
### [xlarge](https://docs.cohere.ai/docs/generation-card#model-description)
- Training Data:
[`coheretext-filtered` dataset](https://docs.cohere.ai/docs/data-statement)
- 200GB of filtered text (3TB unfiltered) from the Google Books dataset,
CommonCrawl, and text scraped by Cohere
- English documents only
- Filtered "harmful, biased, or otherwise undesirable documents"
- Model architecture: Generative Pretrained Transformer
- Model Performance:
- Hellaswag Accuracy, Zero-Shot: 0.805
- PIQA Likelihood, Zero-Shot: 0.824
- Cohere also reported
[safety benchmarks](https://docs.cohere.ai/docs/generation-card#safety-benchmarks)
### [multilingual-22-12](https://docs.cohere.ai/docs/multilingual-language-models)
- Multilingual model was trained using dot product calculations
- Model Performance:
- Clustering: 51.0
- Search-English: 55.8
- Search-Multilingual: 51.4
- Cross-lingual Classification: 64.6
- Cohere's multilingual model outperformed: Sentence-transformers:
`paraphrase-multilingual-mpnet-base-v2`, Google: `LaBSE`, Google:
`Universal Sentence Encoder` in all the above categories according to
Cohere.
## OpenAssistant for Grounded QA
OpenAssistant may fulfill a similar role as the `xlarge` Cohere model in the
grounded QA system if it can:
1. Generate a contextualized question from a chat history
2. Generate a sample answer to compare with search results
3. Generate an answer conditioned on the top 2 most similar search results
Perhaps these tasks could be work packages and get assigned to human annotators
to create examples of the input and output for each task.
OpenAssistant must also be able to identify when it is appropriate to search the
internet. The Cohere system assumes every message from the user is a question
and searches the internet for an answer. OpenAssistant would also need a way to
indicate to an internal system that it "wants" to search the internet.
Perhaps OpenAssistant could prefix every message it sends with a recipient ID.
If it wishes to send a command to an internal system, it could prefix the
message with something like `CMD:`, whereas if it wants to communicate with the
user, it could prefix its message with `USR:`.
This system may allow for flexible communication between OpenAssistant and one
or more conversational systems.
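On the receiving side, the recipient prefix idea could be handled with a small router like the sketch below. The `CMD:` and `USR:` prefixes come from the proposal above; the recipient names, the function `route_message`, and the choice to reject unprefixed messages are assumptions of this example.

```python
# Mapping from prefix to recipient; the recipient names are hypothetical.
RECIPIENTS = {"CMD:": "system", "USR:": "user"}


def route_message(raw):
    """Split a model output into (recipient, payload) based on its prefix."""
    for prefix, recipient in RECIPIENTS.items():
        if raw.startswith(prefix):
            return recipient, raw[len(prefix):].strip()
    # Assumption: a message without a known prefix is treated as an error.
    raise ValueError(f"message has no known recipient prefix: {raw!r}")


command = route_message("CMD: search weather in Paris")
reply = route_message("USR: It is sunny in Paris today.")
```

Adding another conversational system would then only require a new entry in the prefix table.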
Examples of this prefix system would need to be taught to OpenAssistant through
training data that contains such syntax. Perhaps such examples could be
generated through the work packages system.
+79
View File
@@ -0,0 +1,79 @@
# Supervised datasets
For discussion about usage of supervised data see issue
<https://github.com/LAION-AI/Open-Assistant/issues/186>.
## Motivation
An important part of making the assistant useful is to teach it to understand
and follow instructions, and to perform a large set of tasks well.
While RLHF seems like the main ingredient, using existing supervised data might
help.
There are two large-scale projects in the area of instruction-following /
multitask learning: Promptsource and Natural Instructions. These projects
crowdsourced templates and turned existing NLP datasets into
instruction-following seq2seq form in natural language. They include both
long-output training examples, like generating a sentence that is a likely
consequence of a sentence in the prompt, and short-output ones, like predicting
a rating from a review. (Pre-)training on such datasets should help the model
understand and follow instructions and teach it many abilities necessary to
perform a large set of tasks correctly. However, these data are not dialog-like:
they do not look like a normal conversation.
There are also supervised dialog datasets such as Blended Skill Talk or SODA. In
contrast to instruction-following datasets, dialog data is not as focused on
"academic tasks" or correctness, but encourages the model to respond naturally,
like a person would.
### Promptsource
- GitHub: <https://github.com/bigscience-workshop/promptsource>
- paper:
[Multitask Prompted Training Enables Zero-Shot Task Generalization](https://arxiv.org/abs/2110.08207)
- project for preparing templates and working with them
- they generated a dataset using the templates:
- <https://huggingface.co/datasets/bigscience/P3>
- <https://huggingface.co/datasets/bigscience/xP3> (with multilingual data but
English prompt)
- <https://huggingface.co/datasets/bigscience/xP3mt> (with multilingual data
and machine-translated prompt)
- they trained zero-shot models (= models for following instructions in the
input)
- based on T5 architecture (encoder-decoder) called T0 family (and MT0 for
multilingual)
- and based on GPT architecture (decoder-only) called BloomZ family
- Huggingface demo: [T0](https://huggingface.co/bigscience/T0pp),
[MT0](https://huggingface.co/bigscience/mt0-large),
[BloomZ](https://huggingface.co/bigscience/bloomz)
- GitHub repo for T0: <https://github.com/bigscience-workshop/t-zero>
- GitHub repo for BloomZ and MT0:
<https://github.com/bigscience-workshop/xmtf>
### Natural instructions
- GitHub: <https://github.com/allenai/natural-instructions>
- paper:
[Super-NaturalInstructions: Generalization via Declarative Instructions on 1600+ NLP Tasks](https://arxiv.org/abs/2204.07705)
- they crowdsource directly the data prepared for instruction following (and
learning from a few examples)
- the GitHub repo = the dataset. It contains jsons
- they trained zero-shot and in-context few-shot models (in multiple sizes):
- mT5 architecture (encoder-decoder, multilingual pretraining)
- Huggingface demo few-shot:
<https://huggingface.co/allenai/tk-instruct-3b-def-pos>
- Huggingface demo zero-shot:
<https://huggingface.co/allenai/tk-instruct-3b-def>
### Blended Skill Talk
- used by Facebook in Blenderbot project
- HuggingFace dataset: <https://huggingface.co/datasets/blended_skill_talk>
- example model trained on it:
<https://huggingface.co/facebook/blenderbot_small-90M>
### SODA
- GitHub: <https://github.com/skywalker023/sodaverse>
- paper: <https://arxiv.org/abs/2212.10465>
+55
View File
@@ -0,0 +1,55 @@
# Sections to train Reward Model (RM)
The trainer code is based on Hugging Face and is compatible with DeepSpeed or
Accelerate.
Requirements
```
wandb
evaluate
datasets
transformers
torch==1.12
```
Start training reward model
```bash
python trainer.py configs/electra-base-dis-webgpt.yml
```
Additional axis labeling: this outputs four summary quality evaluation metrics
(scores are normalized to 0-1)
```bash
python summary_quality_trainer.py configs/test-bloomz-560m-quality.yml
```
The four summary metrics are:
- overall
- accuracy
- coverage
- coherence
## Dataset
For now we only support the WebGPT and summary datasets from OpenAI. Once the
Open-Assistant dataset is available, it will be added here.
## Model
Check out configs
```
Open-Assistant/model/reward/instructor/configs/
bloomz-560m.yml
electra-base-dis-webgpt.yml
galactica-125m.yml
galactica-1b.yml
```
You can add any new Hugging Face model you want.
+24
View File
@@ -0,0 +1,24 @@
Some other reward features we can use
0. Finish classification feature
1. Summaries from human feedback
- incorporate the `confidence` score into RM learning, and ensure the output
rank score correlates with confidence
- each labeling has a labeling `note`, basically comments by the labeler; not
sure what else we can use
- ~~Use the score for "overall", "accuracy", "coverage", "coherence" from
axis/evals to train an additional model (rank additional aspect of the policy
model)~~
- this should be placed under experimental_dataset.py
2. Add support for anthropic dataset
- the Anthropic dataset is more like a conversation tree, which is much more
complex than a simple question-answer schema
- this is basically an MCTS from AlphaZero.
+64
View File
@@ -0,0 +1,64 @@
"""
classification based ranking
"""
import json
import os
import random

from datasets import load_dataset
from torch.utils.data import Dataset

from .utils import webgpt_return_format


class WebGPTDataset(Dataset):
    def __init__(self, mode="train", index_cache="dataset/webgpt_train_idx.pt", additional_dataset=None) -> None:
        """
        mode : train or val, used for validation purpose, has nothing to do with original split
        additional_dataset : a list of jsonline format with idx, question and texts (generate candidates)
            idx : must match the index you iterate from comparison enumerate order
            question : for validation purpose
            texts : list of K generate results from the question prompt
        """
        super().__init__()
        os.makedirs("dataset", exist_ok=True)
        dataset = load_dataset("openai/webgpt_comparisons")
        self.dataset = []
        self.dataset_index = []
        for idx, row in enumerate(dataset["train"]):
            self.dataset.append(webgpt_return_format(row))
            # record the enumeration index so additional samples can be matched by idx
            self.dataset_index.append(idx)

        # since this dataset was generated from 175B GPT-3
        # we needed some more sample generated from the starting model
        # since this model must rank model generated by GPT-3 being better than your starting model
        self.sample_additional = False
        if additional_dataset is not None:
            self.sample_additional = True
            self.additional = {}
            with open(additional_dataset, "r") as f:
                for line in f:
                    row = json.loads(line)
                    if row["idx"] in self.dataset_index:
                        self.additional[row["idx"]] = row["negatives"]
            # fill gaps: reuse the negatives of an earlier index for rows that have none
            if len(self.additional) != len(self.dataset_index):
                for match_idx in self.dataset_index:
                    if match_idx in self.additional:
                        continue
                    idx = match_idx - 900
                    while idx not in self.additional:
                        idx -= 1
                    self.additional[match_idx] = self.additional[idx]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        row = self.dataset[index]
        if not self.sample_additional:
            return row["question"], row["pos"], row["neg"]
        # also return one randomly chosen generated negative for this row
        gen_neg = random.choice(self.additional[self.dataset_index[index]])
        return row["question"], row["pos"], row["neg"], gen_neg

Some files were not shown because too many files have changed in this diff