mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into dark-mode-implementation
This commit is contained in:
@@ -12,7 +12,7 @@ on:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
build-frontend:
|
||||
build-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16.x
|
||||
cache: 'npm'
|
||||
cache: "npm"
|
||||
cache-dependency-path: website/package-lock.json
|
||||
- run: npm ci
|
||||
- run: npm run build
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
name: Test API Contract
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
test-contract:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- run: cd oasst-shared && pip install -e .
|
||||
|
||||
- run: cd backend && pip install -r requirements.txt
|
||||
|
||||
- run: cd discord-bot && pip install -r requirements.txt
|
||||
|
||||
- run: cd discord-bot && pip install -r requirements.dev.txt
|
||||
|
||||
- run: ./scripts/backend-development/start-mock-server.sh
|
||||
|
||||
# runs the contract tests. currently the api client is
|
||||
# found in the discord bot code, but this should be updated
|
||||
# once the client moves into oasst-shared.
|
||||
- name: Run contract tests
|
||||
run: ./scripts/discord-bot-development/test.sh
|
||||
|
||||
- run: ./scripts/backend-development/stop-mock-server.sh
|
||||
@@ -5,3 +5,6 @@
|
||||
*.egg-info
|
||||
__pycache__
|
||||
.DS_Store
|
||||
|
||||
# Generated files
|
||||
backend/oasst-openapi.json
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
exclude: "build|stubs|^bot/templates/"
|
||||
exclude: "build|stubs|^bot/templates/|^notebooks/.*\\.ipynb$"
|
||||
|
||||
default_language_version:
|
||||
python: python3
|
||||
@@ -50,14 +50,15 @@ repos:
|
||||
rev: v2.7.1
|
||||
hooks:
|
||||
- id: prettier
|
||||
args: ["--write"]
|
||||
args: ["--prose-wrap=always", "--write"]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: next-lint-website
|
||||
name: Lint website
|
||||
files: ^website/
|
||||
exclude: ^website/node_modules/
|
||||
types_or: [javascript, jsx, ts, tsx]
|
||||
language: system
|
||||
language: node
|
||||
pass_filenames: false
|
||||
entry: bash -c 'cd website && npm ci && npm run lint'
|
||||
entry: website/next-lint.js
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
# Open-Assistant
|
||||
|
||||
Open Assistant is a project meant to give everyone access to a great chat based large language model.
|
||||
Open Assistant is a project meant to give everyone access to a great chat based
|
||||
large language model.
|
||||
|
||||
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope Open Assistant can help improve the world by improving language itself.
|
||||
We believe that by doing this we will create a revolution in innovation in
|
||||
language. In the same way that stable-diffusion helped the world make art and
|
||||
images in new ways we hope Open Assistant can help improve the world by
|
||||
improving language itself.
|
||||
|
||||
## Do you want to try it out?
|
||||
|
||||
If you are interested in taking a look at the current state of the project, You can set up an entire stack needed to run **Open-Assistant**, including the
|
||||
If you are interested in taking a look at the current state of the project, You
|
||||
can set up an entire stack needed to run **Open-Assistant**, including the
|
||||
website, backend, and associated dependent services.
|
||||
|
||||
To start the demo, Run this in the root directory of the repository:
|
||||
@@ -15,37 +20,67 @@ To start the demo, Run this in the root directory of the repository:
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and interact with the website.
|
||||
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
|
||||
interact with the website.
|
||||
|
||||
**Note:** When logging in via email, navigate to `http://localhost:1080` to get the magic email login link.
|
||||
**Note:** When logging in via email, navigate to `http://localhost:1080` to get
|
||||
the magic email login link.
|
||||
|
||||
## The Plan
|
||||
|
||||
We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
|
||||
We want to get to an initial MVP as fast as possible, by following the 3-steps
|
||||
outlined in the InstructGPT paper.
|
||||
|
||||
1. Collect high-quality human generated Instruction-Fulfillment samples (prompt + response), goal >50k. We design a crowdsourced process to collect and reviewed prompts. We do not want to train on flooding/toxic/spam/junk/personal information data. We will have a leaderboard to motivate the community that shows progress and the most active users. Swag will be given to the top-contributors.
|
||||
2. For each of the collected prompts we will sample multiple completions. Completions of one prompt will then be shown randomly to users to rank them from best to worst. Again this should happen crowd-sourced, e.g. we need to deal with unreliable potentially malicious users. At least multiple votes by independent users have to be collected to measure the overall agreement. The gathered ranking-data will be used to train a reward model.
|
||||
3. Now follows the RLHF training phase based on the prompts and the reward model.
|
||||
1. Collect high-quality human generated Instruction-Fulfillment samples
|
||||
(prompt + response), goal >50k. We design a crowdsourced process to collect
|
||||
and reviewed prompts. We do not want to train on
|
||||
flooding/toxic/spam/junk/personal information data. We will have a
|
||||
leaderboard to motivate the community that shows progress and the most active
|
||||
users. Swag will be given to the top-contributors.
|
||||
2. For each of the collected prompts we will sample multiple completions.
|
||||
Completions of one prompt will then be shown randomly to users to rank them
|
||||
from best to worst. Again this should happen crowd-sourced, e.g. we need to
|
||||
deal with unreliable potentially malicious users. At least multiple votes by
|
||||
independent users have to be collected to measure the overall agreement. The
|
||||
gathered ranking-data will be used to train a reward model.
|
||||
3. Now follows the RLHF training phase based on the prompts and the reward
|
||||
model.
|
||||
|
||||
We can then take the resulting model and continue with completion sampling step 2 for a next iteration.
|
||||
We can then take the resulting model and continue with completion sampling step
|
||||
2 for a next iteration.
|
||||
|
||||
## The Vision
|
||||
|
||||
We are not going to stop at replicating ChatGPT. We want to build the assistant of the future, able to not only write email and cover letters, but do meaningful work, use APIs, dynamically research information, and much more, with the ability to be personalized and extended by anyone. And we want to do this in a way that is open and accessible, which means we must not only build a great assistant, but also make it small and efficient enough to run on consumer hardware.
|
||||
We are not going to stop at replicating ChatGPT. We want to build the assistant
|
||||
of the future, able to not only write email and cover letters, but do meaningful
|
||||
work, use APIs, dynamically research information, and much more, with the
|
||||
ability to be personalized and extended by anyone. And we want to do this in a
|
||||
way that is open and accessible, which means we must not only build a great
|
||||
assistant, but also make it small and efficient enough to run on consumer
|
||||
hardware.
|
||||
|
||||
### Slide Decks
|
||||
|
||||
[Vision & Roadmap](https://docs.google.com/presentation/d/1n7IrAOVOqwdYgiYrXc8Sj0He8krn5MVZO_iLkCjTtu0/edit?usp=sharing)
|
||||
|
||||
[Important Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
|
||||
|
||||
## How can you help?
|
||||
|
||||
All open source projects begins with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity.
|
||||
All open source projects begins with people like you. Open source is the belief
|
||||
that if we collaborate we can together gift our knowledge and technology to the
|
||||
world for the benefit of humanity.
|
||||
|
||||
## I’m in! Now what?
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
|
||||
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
|
||||
this is for work coordination.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord)
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
|
||||
a dedicated channel and is more public.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
|
||||
dedicated, but not as active, channel.
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
@@ -53,30 +88,41 @@ All open source projects begins with people like you. Open source is the belief
|
||||
|
||||
We have a growing task list
|
||||
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
|
||||
that appeals to you and make a comment that you'd like to work on it. Include
|
||||
in your comment a brief description of how you'll solve the problem and if
|
||||
there are any open questions you want to discuss. Once a project coordinator
|
||||
has assigned the issue to you, start working on it.
|
||||
that appeals to you and make a comment that you'd like to work on it. Include in
|
||||
your comment a brief description of how you'll solve the problem and if there
|
||||
are any open questions you want to discuss. Once a project coordinator has
|
||||
assigned the issue to you, start working on it.
|
||||
|
||||
If the issue is currently unclear but you are interested, please post in
|
||||
Discord and someone can help clarify the issue with more detail.
|
||||
If the issue is currently unclear but you are interested, please post in Discord
|
||||
and someone can help clarify the issue with more detail.
|
||||
|
||||
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
|
||||
the system architecture, and other documentation.
|
||||
|
||||
### Submitting Work
|
||||
|
||||
We're all working on different parts of Open Assistant together. To make
|
||||
contributions smoothly we recommend the following:
|
||||
|
||||
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
|
||||
and clone it to your local machine. (Read more
|
||||
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
|
||||
1. Before working on any changes, try to
|
||||
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
|
||||
to keep it up-to-date with the upstream repository.
|
||||
1. Work on a small focused change that only touches on a few files.
|
||||
1. Run `pre-commit` and make sure all files have formatting fixed. This
|
||||
simplifies life for reviewers.
|
||||
1. Package up a small bit of work that solves part of the problem into a Pull
|
||||
Request and send it out for review
|
||||
1. Package up a small bit of work that solves part of the problem
|
||||
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
|
||||
and
|
||||
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
|
||||
1. If you're lucky, we can merge your change into `main` without any problems.
|
||||
If there's changes to files you're working on, resolve them by:
|
||||
1. First try rebase as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase)
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-rebase).
|
||||
1. If rebase feels too painful, merge as suggested
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge)
|
||||
[in these instructions](https://timwise.co.uk/2019/10/14/merge-vs-rebase/#should-you-merge).
|
||||
1. Once you've resolved any conflicts, finish the review and merge into `main`.
|
||||
1. Merge in your change and move onto a new issue or the second step of your
|
||||
current issue.
|
||||
@@ -95,20 +141,27 @@ addressed now, or filing an issue to handle it later.
|
||||
|
||||
## Developer Setup
|
||||
|
||||
Work is organized in the [project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
Work is organized in the
|
||||
[project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs. Meaning we'd be happy for anyone to do these tasks.**
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs.
|
||||
Meaning we'd be happy for anyone to do these tasks.**
|
||||
|
||||
If you want to work on something, assign yourself to it or write a comment that you want to work on it and what you plan to do.
|
||||
If you want to work on something, assign yourself to it or write a comment that
|
||||
you want to work on it and what you plan to do.
|
||||
|
||||
- To get started with development, if you want to work on the backend, have a look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at `scripts/frontend-development/README.md` to make a backend available.
|
||||
- To get started with development, if you want to work on the backend, have a
|
||||
look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at
|
||||
`scripts/frontend-development/README.md` to make a backend available.
|
||||
|
||||
There is also a minimal implementation of a frontend in the `text-frontend` folder.
|
||||
There is also a minimal implementation of a frontend in the `text-frontend`
|
||||
folder.
|
||||
|
||||
We are using Python 3.10 for the backend.
|
||||
|
||||
Check out the [High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
Check out the
|
||||
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
|
||||
### Website
|
||||
|
||||
@@ -116,10 +169,25 @@ The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit hooks.
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit
|
||||
hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
In case you haven't done this, have already committed, and CI is failing, you
|
||||
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
### Deployment
|
||||
|
||||
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
|
||||
Upon making a release on GitHub, all docker images are automatically built and
|
||||
pushed to ghcr.io. The docker images are tagged with the release version, and
|
||||
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
|
||||
automatically deploy the built release to the dev machine.
|
||||
|
||||
### Problems and Solutions
|
||||
|
||||
- **I am on Ubuntu and getting
|
||||
`ERROR: The Compose file is invalid because:Service backend has neither an image nor a build context specified. At least one must be provided.`**
|
||||
|
||||
Make sure you have an up-to-date version of docker installed, and also install
|
||||
`docker-compose-plugin`. See
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/issues/208) for more
|
||||
details.
|
||||
|
||||
+5
-2
@@ -2,7 +2,9 @@
|
||||
|
||||
## REST Server Configuration
|
||||
|
||||
Please either use environment variables or create a `.env` file in the backend root directory (in which this readme file is located) to specify the `DATABASE_URI`.
|
||||
Please either use environment variables or create a `.env` file in the backend
|
||||
root directory (in which this readme file is located) to specify the
|
||||
`DATABASE_URI`.
|
||||
|
||||
Example contents of a `.env` file for the backend:
|
||||
|
||||
@@ -14,4 +16,5 @@ BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://local
|
||||
|
||||
## Running the REST Server locally for development
|
||||
|
||||
Have a look into the main `README.md` file for more information on how to set up the backend for development.
|
||||
Have a look into the main `README.md` file for more information on how to set up
|
||||
the backend for development.
|
||||
|
||||
+339
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""name changes: person->user, post->message, work_package->task
|
||||
|
||||
Revision ID: abb47e9d145a
|
||||
Revises: 73ce3675c1f5
|
||||
Create Date: 2022-12-30 20:54:49.880568
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abb47e9d145a"
|
||||
down_revision = "73ce3675c1f5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# clear DB
|
||||
op.execute("DELETE FROM journal;")
|
||||
op.execute("DELETE FROM work_package;")
|
||||
op.execute("DELETE FROM post_reaction;")
|
||||
op.execute("DELETE FROM post;")
|
||||
op.execute("DELETE FROM person_stats;")
|
||||
op.execute("DELETE FROM person;")
|
||||
op.execute("DELETE FROM text_labels;")
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("auth_method", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_user_username", "user", ["api_client_id", "username", "auth_method"], unique=True)
|
||||
op.create_table(
|
||||
"message",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("depth", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||
sa.Column("children_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||
sa.Column("parent_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("role", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("lang", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_message_frontend_message_id", "message", ["api_client_id", "frontend_message_id"], unique=True)
|
||||
op.create_index(op.f("ix_message_message_tree_id"), "message", ["message_tree_id"], unique=False)
|
||||
op.create_index(op.f("ix_message_task_id"), "message", ["task_id"], unique=False)
|
||||
op.create_index(op.f("ix_message_user_id"), "message", ["user_id"], unique=False)
|
||||
op.create_table(
|
||||
"task",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("expiry_date", sa.DateTime(), nullable=True),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("done", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("ack", sa.Boolean(), nullable=True),
|
||||
sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("parent_message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_task_user_id"), "task", ["user_id"], unique=False)
|
||||
op.create_table(
|
||||
"user_stats",
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("leader_score", sa.Integer(), nullable=False),
|
||||
sa.Column("reactions", sa.Integer(), nullable=False),
|
||||
sa.Column("messages", sa.Integer(), nullable=False),
|
||||
sa.Column("upvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("downvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("task_reward", sa.Integer(), nullable=False),
|
||||
sa.Column("compare_wins", sa.Integer(), nullable=False),
|
||||
sa.Column("compare_losses", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"message_reaction",
|
||||
sa.Column("task_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["task_id"],
|
||||
["task.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("task_id", "user_id"),
|
||||
)
|
||||
|
||||
op.drop_constraint("text_labels_post_id_fkey", "text_labels", type_="foreignkey")
|
||||
op.drop_constraint("journal_post_id_fkey", "journal", type_="foreignkey")
|
||||
op.drop_constraint("journal_person_id_fkey", "journal", type_="foreignkey")
|
||||
|
||||
op.drop_table("post_reaction")
|
||||
|
||||
op.drop_index("ix_post_frontend_post_id", table_name="post")
|
||||
op.drop_index("ix_post_person_id", table_name="post")
|
||||
op.drop_index("ix_post_thread_id", table_name="post")
|
||||
op.drop_index("ix_post_workpackage_id", table_name="post")
|
||||
op.drop_table("post")
|
||||
|
||||
op.drop_index("ix_work_package_person_id", table_name="work_package")
|
||||
op.drop_table("work_package")
|
||||
op.drop_table("person_stats")
|
||||
|
||||
op.drop_index("ix_person_username", table_name="person")
|
||||
op.drop_table("person")
|
||||
|
||||
op.add_column("journal", sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.add_column("journal", sa.Column("message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.drop_index("ix_journal_person_id", table_name="journal")
|
||||
op.create_index(op.f("ix_journal_user_id"), "journal", ["user_id"], unique=False)
|
||||
|
||||
op.create_foreign_key(None, "journal", "user", ["user_id"], ["id"])
|
||||
op.create_foreign_key(None, "journal", "message", ["message_id"], ["id"])
|
||||
op.drop_column("journal", "person_id")
|
||||
op.drop_column("journal", "post_id")
|
||||
op.add_column("text_labels", sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True))
|
||||
op.create_foreign_key(None, "text_labels", "message", ["message_id"], ["id"])
|
||||
op.drop_column("text_labels", "post_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# clear DB
|
||||
op.execute("DELETE FROM journal;")
|
||||
op.execute("DELETE FROM message_reaction;")
|
||||
op.execute("DELETE FROM task;")
|
||||
op.execute("DELETE FROM message;")
|
||||
op.execute("DELETE FROM user_stats;")
|
||||
op.execute('DELETE FROM "user";')
|
||||
op.execute("DELETE FROM text_labels;")
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("text_labels", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint("text_labels_message_id_fkey", "text_labels", type_="foreignkey")
|
||||
|
||||
op.drop_column("text_labels", "message_id")
|
||||
op.add_column("journal", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
op.add_column("journal", sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint("journal_message_id_fkey", "journal", type_="foreignkey")
|
||||
op.drop_constraint("journal_user_id_fkey", "journal", type_="foreignkey")
|
||||
|
||||
op.drop_index(op.f("ix_journal_user_id"), table_name="journal")
|
||||
op.create_index("ix_journal_person_id", "journal", ["person_id"], unique=False)
|
||||
op.drop_column("journal", "message_id")
|
||||
op.drop_column("journal", "user_id")
|
||||
|
||||
op.create_table(
|
||||
"person",
|
||||
sa.Column(
|
||||
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column("username", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
|
||||
sa.Column("display_name", sa.VARCHAR(length=256), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("auth_method", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="person_api_client_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="person_pkey"),
|
||||
)
|
||||
op.create_table(
|
||||
"person_stats",
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("leader_score", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"modified_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("reactions", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("posts", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("upvotes", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("downvotes", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("work_reward", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("compare_wins", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("compare_losses", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="person_stats_person_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("person_id", name="person_stats_pkey"),
|
||||
)
|
||||
op.create_table(
|
||||
"work_package",
|
||||
sa.Column(
|
||||
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("expiry_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("done", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
|
||||
sa.Column("ack", sa.BOOLEAN(), autoincrement=False, nullable=True),
|
||||
sa.Column("frontend_ref_post_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("parent_post_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("collective", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="work_package_api_client_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="work_package_person_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="work_package_pkey"),
|
||||
)
|
||||
op.create_index("ix_work_package_person_id", "work_package", ["person_id"], unique=False)
|
||||
op.create_table(
|
||||
"post",
|
||||
sa.Column(
|
||||
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column("parent_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("workpackage_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("role", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
|
||||
sa.Column("frontend_post_id", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
|
||||
sa.Column("depth", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("children_count", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("lang", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_api_client_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_person_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="post_pkey"),
|
||||
)
|
||||
op.create_index("ix_post_workpackage_id", "post", ["workpackage_id"], unique=False)
|
||||
op.create_index("ix_post_thread_id", "post", ["thread_id"], unique=False)
|
||||
op.create_index("ix_post_person_id", "post", ["person_id"], unique=False)
|
||||
op.create_index("ix_post_frontend_post_id", "post", ["api_client_id", "frontend_post_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"post_reaction",
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("work_package_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_reaction_api_client_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_reaction_person_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["work_package_id"], ["work_package.id"], name="post_reaction_work_package_id_fkey"),
|
||||
)
|
||||
|
||||
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=False)
|
||||
op.create_foreign_key("text_labels_post_id_fkey", "text_labels", "post", ["post_id"], ["id"])
|
||||
op.create_foreign_key("journal_person_id_fkey", "journal", "person", ["person_id"], ["id"])
|
||||
op.create_foreign_key("journal_post_id_fkey", "journal", "post", ["post_id"], ["id"])
|
||||
|
||||
op.drop_table("message_reaction")
|
||||
op.drop_table("user_stats")
|
||||
op.drop_index(op.f("ix_task_user_id"), table_name="task")
|
||||
op.drop_table("task")
|
||||
op.drop_index(op.f("ix_message_user_id"), table_name="message")
|
||||
op.drop_index(op.f("ix_message_task_id"), table_name="message")
|
||||
op.drop_index(op.f("ix_message_message_tree_id"), table_name="message")
|
||||
op.drop_index("ix_message_frontend_message_id", table_name="message")
|
||||
op.drop_table("message")
|
||||
op.drop_index("ix_user_username", table_name="user")
|
||||
op.drop_table("user")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add deleted field to post
|
||||
|
||||
Revision ID: 8d269bc4fdbd
|
||||
Revises: abb47e9d145a
|
||||
Create Date: 2022-12-31 04:38:41.799206
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8d269bc4fdbd"
|
||||
down_revision = "abb47e9d145a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message", "deleted")
|
||||
# ### end Alembic commands ###
|
||||
+91
-60
@@ -67,10 +67,10 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
def seed_data():
|
||||
class DummyPost(pydantic.BaseModel):
|
||||
task_post_id: str
|
||||
user_post_id: str
|
||||
parent_post_id: Optional[str]
|
||||
class DummyMessage(pydantic.BaseModel):
|
||||
task_message_id: str
|
||||
user_message_id: str
|
||||
parent_message_id: Optional[str]
|
||||
text: str
|
||||
role: str
|
||||
|
||||
@@ -81,96 +81,97 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
dummy_posts = [
|
||||
DummyPost(
|
||||
task_post_id="de111fa8",
|
||||
user_post_id="6f1d0711",
|
||||
parent_post_id=None,
|
||||
dummy_messages = [
|
||||
DummyMessage(
|
||||
task_message_id="de111fa8",
|
||||
user_message_id="6f1d0711",
|
||||
parent_message_id=None,
|
||||
text="Hi!",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="74c381d4",
|
||||
user_post_id="4a24530b",
|
||||
parent_post_id="6f1d0711",
|
||||
DummyMessage(
|
||||
task_message_id="74c381d4",
|
||||
user_message_id="4a24530b",
|
||||
parent_message_id="6f1d0711",
|
||||
text="Hello! How can I help you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="3d5dc440",
|
||||
user_post_id="a8c01c04",
|
||||
parent_post_id="4a24530b",
|
||||
DummyMessage(
|
||||
task_message_id="3d5dc440",
|
||||
user_message_id="a8c01c04",
|
||||
parent_message_id="4a24530b",
|
||||
text="Do you have a recipe for potato soup?",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="643716c1",
|
||||
user_post_id="f43a93b7",
|
||||
parent_post_id="4a24530b",
|
||||
DummyMessage(
|
||||
task_message_id="643716c1",
|
||||
user_message_id="f43a93b7",
|
||||
parent_message_id="4a24530b",
|
||||
text="Who were the 8 presidents before George Washington?",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="2e4e1e6",
|
||||
user_post_id="c886920",
|
||||
parent_post_id="6f1d0711",
|
||||
DummyMessage(
|
||||
task_message_id="2e4e1e6",
|
||||
user_message_id="c886920",
|
||||
parent_message_id="6f1d0711",
|
||||
text="Hey buddy! How can I serve you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="970c437d",
|
||||
user_post_id="cec432cf",
|
||||
parent_post_id=None,
|
||||
DummyMessage(
|
||||
task_message_id="970c437d",
|
||||
user_message_id="cec432cf",
|
||||
parent_message_id=None,
|
||||
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="6066118e",
|
||||
user_post_id="4f85f637",
|
||||
parent_post_id="cec432cf",
|
||||
DummyMessage(
|
||||
task_message_id="6066118e",
|
||||
user_message_id="4f85f637",
|
||||
parent_message_id="cec432cf",
|
||||
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="ba87780d",
|
||||
user_post_id="0e276b98",
|
||||
parent_post_id="cec432cf",
|
||||
DummyMessage(
|
||||
task_message_id="ba87780d",
|
||||
user_message_id="0e276b98",
|
||||
parent_message_id="cec432cf",
|
||||
text="I'm unsure how to interpret this. Is it a riddle?",
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
|
||||
for p in dummy_posts:
|
||||
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
|
||||
if wp and not wp.ack:
|
||||
logger.warning("Deleting unacknowledged seed data work package")
|
||||
db.delete(wp)
|
||||
wp = None
|
||||
if not wp:
|
||||
if p.parent_post_id is None:
|
||||
wp = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
|
||||
for msg in dummy_messages:
|
||||
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
print("p.parent_post_id", p.parent_post_id)
|
||||
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
|
||||
wp = pr.store_task(
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
|
||||
)
|
||||
),
|
||||
thread_id=parent_post.thread_id,
|
||||
parent_post_id=parent_post.id,
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
pr.bind_frontend_post_id(wp.id, p.task_post_id)
|
||||
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
|
||||
pr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
|
||||
|
||||
logger.info(
|
||||
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data work_package found: {wp.id}")
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
logger.info("Seed data check completed")
|
||||
|
||||
except Exception:
|
||||
@@ -178,3 +179,33 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
def get_openapi_schema():
|
||||
return json.dumps(app.openapi())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Importing here so we don't import packages unnecessarily if we're
|
||||
# importing main as a module.
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--print-openapi-schema",
|
||||
help="Dumps the openapi schema to stdout",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
)
|
||||
parser.add_argument("--host", help="The host to run the server")
|
||||
parser.add_argument("--port", help="The port to run the server")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.print_openapi_schema:
|
||||
print(get_openapi_schema())
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -4,7 +4,7 @@ from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Security
|
||||
from fastapi import Depends, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
@@ -64,3 +64,24 @@ def api_auth(
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
def get_api_client(
|
||||
api_key: APIKey = Depends(get_api_key),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return api_auth(api_key, db)
|
||||
|
||||
|
||||
def get_trusted_api_client(
|
||||
api_key: APIKey = Depends(get_api_key),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
client = api_auth(api_key, db)
|
||||
if not client.trusted:
|
||||
raise OasstError(
|
||||
"Forbidden",
|
||||
error_code=OasstErrorCode.API_CLIENT_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
return client
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import tasks, text_labels
|
||||
from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, stats, tasks, text_labels, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
|
||||
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
|
||||
api_router.include_router(frontend_messages.router, prefix="/frontend_messages", tags=["frontend_messages"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
def get_message_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
# Unexpected message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
def get_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
def get_tree_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
def get_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
def get_descendants_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
def get_max_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{username}/messages")
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.delete("/{username}/messages")
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -0,0 +1,149 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
allow_deleted: bool = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if allow_deleted else False,
|
||||
)
|
||||
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
# Unexptcted message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.delete("/{message_id}")
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def get_message_stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_stats()
|
||||
@@ -18,8 +18,8 @@ router = APIRouter()
|
||||
def generate_task(
|
||||
request: protocol_schema.TaskRequest, pr: PromptRepository
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
thread_id = None
|
||||
parent_post_id = None
|
||||
message_tree_id = None
|
||||
parent_message_id = None
|
||||
|
||||
match request.type:
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
@@ -54,38 +54,42 @@ def generate_task(
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
posts = pr.fetch_random_conversation("assistant")
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
|
||||
for p in posts
|
||||
case protocol_schema.TaskRequestType.prompter_reply:
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = posts[-1].thread_id
|
||||
parent_post_id = posts[-1].id
|
||||
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
posts = pr.fetch_random_conversation("user")
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
|
||||
for p in posts
|
||||
messages = pr.fetch_random_conversation("prompter")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = posts[-1].thread_id
|
||||
parent_post_id = posts[-1].id
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
|
||||
posts = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
|
||||
case protocol_schema.TaskRequestType.rank_user_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
messages = [
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
@@ -93,18 +97,18 @@ def generate_task(
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankUserRepliesTask(
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=messages,
|
||||
messages=task_messages,
|
||||
),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
|
||||
|
||||
messages = [
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
@@ -113,7 +117,7 @@ def generate_task(
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(messages=messages),
|
||||
conversation=protocol_schema.Conversation(messages=task_messages),
|
||||
replies=replies,
|
||||
)
|
||||
case _:
|
||||
@@ -121,7 +125,7 @@ def generate_task(
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task, thread_id, parent_post_id
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
@@ -138,8 +142,8 @@ def request_task(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
task, thread_id, parent_post_id = generate_task(request, pr)
|
||||
pr.store_task(task, thread_id, parent_post_id, request.collective)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -149,8 +153,8 @@ def request_task(
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
def acknowledge_task(
|
||||
@router.post("/{task_id}/ack", response_model=None)
|
||||
def tasks_acknowledge(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -166,20 +170,19 @@ def acknowledge_task(
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
|
||||
# here we store the post id in the database for the task
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
|
||||
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/{task_id}/nack")
|
||||
def acknowledge_task_failure(
|
||||
@router.post("/{task_id}/nack", response_model=None)
|
||||
def tasks_acknowledge_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -200,8 +203,8 @@ def acknowledge_task_failure(
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
def post_interaction(
|
||||
@router.post("/interaction", response_model=protocol_schema.TaskDone)
|
||||
def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -216,29 +219,31 @@ def post_interaction(
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
pr.store_text_reply(
|
||||
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRating:
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the rating in the database
|
||||
pr.store_rating(interaction)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRanking:
|
||||
case protocol_schema.MessageRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# TODO: check if the ranking is valid
|
||||
@@ -262,5 +267,5 @@ def close_collective_task(
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.close_task(close_task_request.post_id)
|
||||
pr.close_task(close_task_request.message_id)
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{user_id}/messages")
|
||||
def query_user_messages(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
max_count: int = Query(10, gt=0, le=1000),
|
||||
start_date: datetime.datetime = None,
|
||||
end_date: datetime.datetime = None,
|
||||
only_roots: bool = False,
|
||||
desc: bool = True,
|
||||
include_deleted: bool = False,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.query_messages(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
desc=desc,
|
||||
limit=max_count,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{user_id}/messages")
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import Message
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
def prepare_message(m: Message) -> protocol.Message:
|
||||
if not isinstance(m.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
return protocol.Message(
|
||||
id=m.id,
|
||||
parent_id=m.parent_id,
|
||||
text=m.payload.payload.text,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
)
|
||||
|
||||
|
||||
def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
return [prepare_message(m) for m in messages]
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
)
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
|
||||
|
||||
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(prepare_message(message))
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
SERVER_ERROR = 3
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
@@ -27,21 +28,22 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
MESSAGE_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_THREADS_FOUND = 2006
|
||||
NO_MESSAGE_TREE_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
WORK_PACKAGE_NOT_FOUND = 2100
|
||||
WORK_PACKAGE_EXPIRED = 2101
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
WORK_PACKAGE_ALREADY_UPDATED = 2103
|
||||
WORK_PACKAGE_NOT_ACK = 2104
|
||||
WORK_PACKAGE_ALREADY_DONE = 2105
|
||||
WORK_PACKAGE_NOT_COLLECTIVE = 2106
|
||||
INVALID_MESSAGE = 2008
|
||||
TASK_NOT_FOUND = 2100
|
||||
TASK_EXPIRED = 2101
|
||||
TASK_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
TASK_ALREADY_UPDATED = 2103
|
||||
TASK_NOT_ACK = 2104
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
|
||||
@@ -3,7 +3,7 @@ import enum
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
|
||||
from oasst_backend.models import ApiClient, Journal, Task, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
@@ -14,71 +14,71 @@ class JournalEventType(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
user_created = "user_created"
|
||||
text_reply_to_post = "text_reply_to_post"
|
||||
post_rating = "post_rating"
|
||||
post_ranking = "post_ranking"
|
||||
text_reply_to_message = "text_reply_to_message"
|
||||
message_rating = "message_rating"
|
||||
message_ranking = "message_ranking"
|
||||
|
||||
|
||||
@payload_type
|
||||
class JournalEvent(BaseModel):
|
||||
type: str
|
||||
person_id: Optional[UUID]
|
||||
post_id: Optional[UUID]
|
||||
workpackage_id: Optional[UUID]
|
||||
user_id: Optional[UUID]
|
||||
message_id: Optional[UUID]
|
||||
task_id: Optional[UUID]
|
||||
task_type: Optional[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class TextReplyEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
|
||||
type: Literal[JournalEventType.text_reply_to_message] = JournalEventType.text_reply_to_message
|
||||
length: int
|
||||
role: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RatingEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
|
||||
type: Literal[JournalEventType.message_rating] = JournalEventType.message_rating
|
||||
rating: int
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
|
||||
type: Literal[JournalEventType.message_ranking] = JournalEventType.message_ranking
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
class JournalWriter:
|
||||
def __init__(self, db: Session, api_client: ApiClient, person: Person):
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: User):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = person
|
||||
self.person_id = self.person.id if self.person else None
|
||||
self.user = user
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
|
||||
def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_post,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_message,
|
||||
payload=TextReplyEvent(role=role, length=length),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
|
||||
def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.post_rating,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_rating,
|
||||
payload=RatingEvent(rating=rating),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
|
||||
def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.post_ranking,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_ranking,
|
||||
payload=RankingEvent(ranking=ranking),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log(
|
||||
@@ -87,8 +87,8 @@ class JournalWriter:
|
||||
payload: JournalEvent,
|
||||
task_type: str,
|
||||
event_type: str = None,
|
||||
workpackage_id: Optional[UUID] = None,
|
||||
post_id: Optional[UUID] = None,
|
||||
task_id: Optional[UUID] = None,
|
||||
message_id: Optional[UUID] = None,
|
||||
commit: bool = True,
|
||||
) -> Journal:
|
||||
if event_type is None:
|
||||
@@ -97,22 +97,22 @@ class JournalWriter:
|
||||
else:
|
||||
event_type = type(payload).__name__
|
||||
|
||||
if payload.person_id is None:
|
||||
payload.person_id = self.person_id
|
||||
if payload.post_id is None:
|
||||
payload.post_id = post_id
|
||||
if payload.workpackage_id is None:
|
||||
payload.workpackage_id = workpackage_id
|
||||
if payload.user_id is None:
|
||||
payload.user_id = self.user_id
|
||||
if payload.message_id is None:
|
||||
payload.message_id = message_id
|
||||
if payload.task_id is None:
|
||||
payload.task_id = task_id
|
||||
if payload.task_type is None:
|
||||
payload.task_type = task_type
|
||||
|
||||
entry = Journal(
|
||||
person_id=self.person_id,
|
||||
user_id=self.user_id,
|
||||
api_client_id=self.api_client.id,
|
||||
created_date=utcnow(),
|
||||
event_type=event_type,
|
||||
event_payload=PayloadContainer(payload=payload),
|
||||
post_id=post_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
self.db.add(entry)
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
from .post_reaction import PostReaction
|
||||
from .message import Message
|
||||
from .message_reaction import MessageReaction
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .work_package import WorkPackage
|
||||
from .user import User
|
||||
from .user_stats import UserStats
|
||||
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
"Person",
|
||||
"PersonStats",
|
||||
"Post",
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
"User",
|
||||
"UserStats",
|
||||
"Message",
|
||||
"MessageReaction",
|
||||
"Task",
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
|
||||
@@ -32,8 +32,8 @@ class InitialPromptPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class UserReplyPayload(TaskPayload):
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
class PrompterReplyPayload(TaskPayload):
|
||||
type: Literal["prompter_reply"] = "prompter_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
hint: str | None
|
||||
|
||||
@@ -45,7 +45,7 @@ class AssistantReplyPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class PostPayload(BaseModel):
|
||||
class MessagePayload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
@@ -56,13 +56,13 @@ class ReactionPayload(BaseModel):
|
||||
|
||||
@payload_type
|
||||
class RatingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
type: Literal["message_rating"] = "message_rating"
|
||||
rating: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
@@ -81,10 +81,10 @@ class RankInitialPromptsPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankUserRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
class RankPrompterRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of prompter replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
|
||||
|
||||
|
||||
@payload_type
|
||||
|
||||
@@ -33,8 +33,8 @@ class Journal(SQLModel, table=True):
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
|
||||
event_type: str = Field(nullable=False, max_length=200)
|
||||
|
||||
@@ -5,14 +5,15 @@ from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class Post(SQLModel, table=True):
|
||||
__tablename__ = "post"
|
||||
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
|
||||
class Message(SQLModel, table=True):
|
||||
__tablename__ = "message"
|
||||
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
@@ -20,12 +21,12 @@ class Post(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
thread_id: UUID = Field(nullable=False, index=True)
|
||||
workpackage_id: UUID = Field(nullable=True, index=True)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128)
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
task_id: UUID = Field(nullable=True, index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_post_id: str = Field(max_length=200, nullable=False)
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
@@ -34,3 +35,4 @@ class Post(SQLModel, table=True):
|
||||
lang: str = Field(nullable=False, max_length=200, default="en-US")
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
+6
-6
@@ -10,14 +10,14 @@ from sqlmodel import Field, SQLModel
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class PostReaction(SQLModel, table=True):
|
||||
__tablename__ = "post_reaction"
|
||||
class MessageReaction(SQLModel, table=True):
|
||||
__tablename__ = "message_reaction"
|
||||
|
||||
work_package_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
|
||||
task_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
person_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
|
||||
user_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
@@ -11,8 +11,8 @@ from sqlmodel import Field, SQLModel
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class WorkPackage(SQLModel, table=True):
|
||||
__tablename__ = "work_package"
|
||||
class Task(SQLModel, table=True):
|
||||
__tablename__ = "task"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
@@ -23,15 +23,15 @@ class WorkPackage(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
ack: Optional[bool] = None
|
||||
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
frontend_ref_post_id: Optional[str] = None
|
||||
thread_id: Optional[UUID] = None
|
||||
parent_post_id: Optional[UUID] = None
|
||||
frontend_message_id: Optional[str] = None
|
||||
message_tree_id: Optional[UUID] = None
|
||||
parent_message_id: Optional[UUID] = None
|
||||
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
@property
|
||||
@@ -21,5 +21,7 @@ class TextLabels(SQLModel, table=True):
|
||||
)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
text: str = Field(nullable=False, max_length=2**16)
|
||||
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
|
||||
message_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
|
||||
)
|
||||
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
|
||||
@@ -8,9 +8,9 @@ import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class Person(SQLModel, table=True):
|
||||
__tablename__ = "person"
|
||||
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "user"
|
||||
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
+8
-8
@@ -8,11 +8,11 @@ import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class PersonStats(SQLModel, table=True):
|
||||
__tablename__ = "person_stats"
|
||||
class UserStats(SQLModel, table=True):
|
||||
__tablename__ = "user_stats"
|
||||
|
||||
person_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
|
||||
)
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
@@ -20,9 +20,9 @@ class PersonStats(SQLModel, table=True):
|
||||
)
|
||||
|
||||
reactions: int = 0 # reactions sent by user
|
||||
posts: int = 0 # posts sent by user
|
||||
messages: int = 0 # messages sent by user
|
||||
upvotes: int = 0 # received upvotes (form other users)
|
||||
downvotes: int = 0 # received downvotes (from other users)
|
||||
work_reward: int = 0 # reward for workpackage completions
|
||||
compare_wins: int = 0 # num times user's post won compare tasks
|
||||
compare_losses: int = 0 # num times users's post lost compare tasks
|
||||
task_reward: int = 0 # reward for task completions
|
||||
compare_wins: int = 0 # num times user's message won compare tasks
|
||||
compare_losses: int = 0 # num times users's message lost compare tasks
|
||||
@@ -1,5 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -7,257 +10,259 @@ import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
self.journal = JournalWriter(db, api_client, self.person)
|
||||
self.user = self.lookup_user(user)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
person: Person = (
|
||||
self.db.query(Person)
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
Person.api_client_id == self.api_client.id,
|
||||
Person.username == user.id,
|
||||
Person.auth_method == user.auth_method,
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if person is None:
|
||||
if user is None:
|
||||
# user is unknown, create new record
|
||||
person = Person(
|
||||
username=user.id,
|
||||
display_name=user.display_name,
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=user.auth_method,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(person)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(person)
|
||||
elif user.display_name and user.display_name != person.display_name:
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
person.display_name = user.display_name
|
||||
self.db.add(person)
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return person
|
||||
return user
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
|
||||
if not post_id:
|
||||
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
|
||||
def validate_frontend_message_id(self, message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
work_pack.frontend_ref_post_id = post_id
|
||||
work_pack.ack = True
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(work_pack)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
work_pack.ack = False
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(work_pack)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
|
||||
self.validate_post_id(frontend_post_id)
|
||||
post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
return post
|
||||
if fail_if_missing and message is None:
|
||||
raise OasstError(
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.",
|
||||
OasstErrorCode.MESSAGE_NOT_FOUND,
|
||||
HTTP_404_NOT_FOUND,
|
||||
)
|
||||
return message
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
self.validate_post_id(post_id)
|
||||
work_pack = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_post_id == post_id)
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
self.validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return work_pack
|
||||
return task
|
||||
|
||||
def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str = None) -> Post:
|
||||
self.validate_post_id(post_id)
|
||||
self.validate_post_id(user_post_id)
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if wp is None:
|
||||
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not wp.ack:
|
||||
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
|
||||
if wp.done:
|
||||
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if not task.ack:
|
||||
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
|
||||
if task.done:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
# If there's no parent post assume user started new conversation
|
||||
role = "user"
|
||||
# If there's no parent message assume user started new conversation
|
||||
role = "prompter"
|
||||
depth = 0
|
||||
|
||||
if wp.parent_post_id:
|
||||
parent_post = self.fetch_post(wp.parent_post_id)
|
||||
parent_post.children_count += 1
|
||||
self.db.add(parent_post)
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
parent_message.children_count += 1
|
||||
self.db.add(parent_message)
|
||||
|
||||
depth = parent_post.depth + 1
|
||||
if parent_post.role == "assistant":
|
||||
role = "user"
|
||||
depth = parent_message.depth + 1
|
||||
if parent_message.role == "assistant":
|
||||
role = "prompter"
|
||||
else:
|
||||
role = "assistant"
|
||||
|
||||
# create reply post
|
||||
new_post_id = uuid4()
|
||||
user_post = self.insert_post(
|
||||
post_id=new_post_id,
|
||||
frontend_post_id=user_post_id,
|
||||
parent_id=wp.parent_post_id,
|
||||
thread_id=wp.thread_id or new_post_id,
|
||||
workpackage_id=wp.id,
|
||||
# create reply message
|
||||
new_message_id = uuid4()
|
||||
user_message = self.insert_message(
|
||||
message_id=new_message_id,
|
||||
frontend_message_id=user_frontend_message_id,
|
||||
parent_id=task.parent_message_id,
|
||||
message_tree_id=task.message_tree_id or new_message_id,
|
||||
task_id=task.id,
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=text),
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
depth=depth,
|
||||
)
|
||||
if not wp.collective:
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
|
||||
return user_post
|
||||
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
|
||||
return user_message
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(rating.post_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
task = self.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(task_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
f"Task payload type mismatch: {type(task_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
if rating.rating < task_payload.scale.min or rating.rating > task_payload.scale.max:
|
||||
raise OasstError(
|
||||
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
|
||||
f"Invalid rating value: {rating.rating=} not in {task_payload.scale=}",
|
||||
OasstErrorCode.RATING_OUT_OF_RANGE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
reaction = self.insert_reaction(message.id, reaction_payload)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
self.journal.log_rating(task, message_id=message.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
# fetch task
|
||||
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
task_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
task.payload.payload
|
||||
)
|
||||
|
||||
match type(work_payload):
|
||||
match type(task_payload):
|
||||
|
||||
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
# validate ranking
|
||||
num_replies = len(work_payload.replies)
|
||||
num_replies = len(task_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
# TODO: resolve post_id
|
||||
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
# TODO: resolve post_id
|
||||
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
@@ -271,8 +276,8 @@ class PromptRepository:
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.UserReplyTask:
|
||||
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
@@ -280,8 +285,8 @@ class PromptRepository:
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankUserRepliesTask:
|
||||
payload = db_payload.RankUserRepliesPayload(
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
@@ -293,81 +298,85 @@ class PromptRepository:
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(
|
||||
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def insert_work_package(
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
task = Task(
|
||||
id=id,
|
||||
person_id=self.person_id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
thread_id=thread_id,
|
||||
parent_post_id=parent_post_id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def insert_post(
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
post_id: UUID,
|
||||
frontend_post_id: str,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
thread_id: UUID,
|
||||
workpackage_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.PostPayload,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Post:
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
post = Post(
|
||||
id=post_id,
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
thread_id=thread_id,
|
||||
workpackage_id=workpackage_id,
|
||||
person_id=self.person_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_post_id=frontend_post_id,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
self.db.add(post)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(post)
|
||||
return post
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
work_package_id=work_package_id,
|
||||
person_id=self.person_id,
|
||||
reaction = MessageReaction(
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
@@ -383,108 +392,317 @@ class PromptRepository:
|
||||
text=text_labels.text,
|
||||
labels=text_labels.labels,
|
||||
)
|
||||
if text_labels.has_post_id:
|
||||
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
|
||||
model.post_id = text_labels.post_id
|
||||
if text_labels.has_message_id:
|
||||
self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
|
||||
model.message_id = text_labels.message_id
|
||||
self.db.add(model)
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
return model
|
||||
|
||||
def fetch_random_thread(self, require_role: str = None) -> list[Post]:
|
||||
def fetch_random_message_tree(self, require_role: str = None) -> list[Message]:
|
||||
"""
|
||||
Loads all posts of a random thread.
|
||||
Loads all messages of a random message_tree.
|
||||
|
||||
:param require_role: If set loads only thread which has
|
||||
at least one post with given role.
|
||||
:param require_role: If set loads only message_tree which has
|
||||
at least one message with given role.
|
||||
"""
|
||||
distinct_threads = self.db.query(Post.thread_id).distinct(Post.thread_id)
|
||||
distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id)
|
||||
if require_role:
|
||||
distinct_threads = distinct_threads.filter(Post.role == require_role)
|
||||
distinct_threads = distinct_threads.subquery()
|
||||
distinct_message_trees = distinct_message_trees.filter(Message.role == require_role)
|
||||
distinct_message_trees = distinct_message_trees.subquery()
|
||||
|
||||
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
|
||||
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
|
||||
return thread_posts
|
||||
random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1)
|
||||
message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all()
|
||||
return message_tree_messages
|
||||
|
||||
def fetch_random_conversation(self, last_post_role: str = None) -> list[Post]:
|
||||
def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]:
|
||||
"""
|
||||
Picks a random linear conversation starting from any root post
|
||||
and ending somewhere in the thread, possibly at the root itself.
|
||||
Picks a random linear conversation starting from any root message
|
||||
and ending somewhere in the message_tree, possibly at the root itself.
|
||||
|
||||
:param last_post_role: If set will form a conversation ending with a post
|
||||
:param last_message_role: If set will form a conversation ending with a message
|
||||
created by this role. Necessary for the tasks like "user_reply" where
|
||||
the user should reply as a human and hence the last message of the conversation
|
||||
needs to have "assistant" role.
|
||||
"""
|
||||
thread_posts = self.fetch_random_thread(last_post_role)
|
||||
if not thread_posts:
|
||||
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
|
||||
if last_post_role:
|
||||
conv_posts = [p for p in thread_posts if p.role == last_post_role]
|
||||
conv_posts = [random.choice(conv_posts)]
|
||||
messages_tree = self.fetch_random_message_tree(last_message_role)
|
||||
if not messages_tree:
|
||||
raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND)
|
||||
if last_message_role:
|
||||
conv_messages = [m for m in messages_tree if m.role == last_message_role]
|
||||
conv_messages = [random.choice(conv_messages)]
|
||||
else:
|
||||
conv_posts = [random.choice(thread_posts)]
|
||||
thread_posts = {p.id: p for p in thread_posts}
|
||||
conv_messages = [random.choice(messages_tree)]
|
||||
messages_tree = {m.id: m for m in messages_tree}
|
||||
|
||||
while True:
|
||||
if not conv_posts[-1].parent_id:
|
||||
if not conv_messages[-1].parent_id:
|
||||
# reached the start of the conversation
|
||||
break
|
||||
|
||||
parent_post = thread_posts[conv_posts[-1].parent_id]
|
||||
conv_posts.append(parent_post)
|
||||
parent_message = messages_tree[conv_messages[-1].parent_id]
|
||||
conv_messages.append(parent_message)
|
||||
|
||||
return list(reversed(conv_posts))
|
||||
return list(reversed(conv_messages))
|
||||
|
||||
def fetch_random_initial_prompts(self, size: int = 5):
|
||||
posts = self.db.query(Post).filter(Post.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return posts
|
||||
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return messages
|
||||
|
||||
def fetch_thread(self, thread_id: UUID):
|
||||
return self.db.query(Post).filter(Post.thread_id == thread_id).all()
|
||||
def fetch_message_tree(self, message_tree_id: UUID):
|
||||
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None):
|
||||
parent = self.db.query(Post.id).filter(Post.children_count > 1)
|
||||
if post_role:
|
||||
parent = parent.filter(Post.role == post_role)
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
"""
|
||||
Fetch a conversation with multiple possible replies to it.
|
||||
|
||||
This function finds a random message with >1 replies,
|
||||
forms a conversation from the corresponding message tree root up to this message
|
||||
and fetches up to max_size possible replies in continuation to this conversation.
|
||||
"""
|
||||
parent = self.db.query(Message.id).filter(Message.children_count > 1)
|
||||
if message_role:
|
||||
parent = parent.filter(Message.role == message_role)
|
||||
|
||||
parent = parent.order_by(func.random()).limit(1)
|
||||
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
|
||||
replies = (
|
||||
self.db.query(Message).filter(Message.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
|
||||
)
|
||||
if not replies:
|
||||
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
|
||||
|
||||
thread = self.fetch_thread(replies[0].thread_id)
|
||||
thread = {p.id: p for p in thread}
|
||||
thread_posts = [thread[replies[0].parent_id]]
|
||||
message_tree = self.fetch_message_tree(replies[0].message_tree_id)
|
||||
message_tree = {p.id: p for p in message_tree}
|
||||
conversation = [message_tree[replies[0].parent_id]]
|
||||
while True:
|
||||
if not thread_posts[-1].parent_id:
|
||||
if not conversation[-1].parent_id:
|
||||
# reached start of the conversation
|
||||
break
|
||||
|
||||
parent_post = thread[thread_posts[-1].parent_id]
|
||||
thread_posts.append(parent_post)
|
||||
parent_message = message_tree[conversation[-1].parent_id]
|
||||
conversation.append(parent_message)
|
||||
|
||||
thread_posts = reversed(thread_posts)
|
||||
conversation = reversed(conversation)
|
||||
|
||||
return thread_posts, replies
|
||||
return conversation, replies
|
||||
|
||||
def fetch_post(self, post_id: UUID) -> Optional[Post]:
|
||||
return self.db.query(Post).filter(Post.id == post_id).one()
|
||||
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
|
||||
self.validate_post_id(post_id)
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not wp:
|
||||
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not allow_personal_tasks and not wp.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
|
||||
if wp.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
Pick messages from a collection so that the result makes a linear conversation
|
||||
starting from a message tree root and up to the given message.
|
||||
Returns an ordered list of messages starting from the message tree root.
|
||||
"""
|
||||
if isinstance(messages, list):
|
||||
messages = {m.id: m for m in messages}
|
||||
if not isinstance(messages, dict):
|
||||
# This should not normally happen
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
conv = [last_message]
|
||||
while conv[-1].parent_id:
|
||||
if conv[-1].parent_id not in messages:
|
||||
# Can't form a continuous conversation
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
parent_message = messages[conv[-1].parent_id]
|
||||
conv.append(parent_message)
|
||||
|
||||
return list(reversed(conv))
|
||||
|
||||
def fetch_message_conversation(self, message: Message | UUID) -> list[Message]:
|
||||
"""
|
||||
Fetch a conversation from the tree root and up to this message.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_message(message)
|
||||
|
||||
tree_messages = self.fetch_message_tree(message.message_tree_id)
|
||||
return self.trace_conversation(tree_messages, message)
|
||||
|
||||
def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]:
|
||||
"""
|
||||
Fetch message tree this message belongs to.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_message(message)
|
||||
return self.fetch_message_tree(message.message_tree_id)
|
||||
|
||||
def fetch_message_children(self, message: Message | UUID) -> list[Message]:
|
||||
"""
|
||||
Get all direct children of this message
|
||||
"""
|
||||
if isinstance(message, Message):
|
||||
message = message.id
|
||||
|
||||
children = self.db.query(Message).filter(Message.parent_id == message).all()
|
||||
return children
|
||||
|
||||
@staticmethod
|
||||
def trace_descendants(root: Message, messages: list[Message]) -> list[Message]:
|
||||
children = defaultdict(list)
|
||||
for msg in messages:
|
||||
children[msg.parent_id].append(msg)
|
||||
|
||||
def _traverse_subtree(m: Message):
|
||||
for child in children[m.id]:
|
||||
yield child
|
||||
yield from _traverse_subtree(child)
|
||||
|
||||
return list(_traverse_subtree(root))
|
||||
|
||||
def fetch_message_descendants(self, message: Message | UUID, max_depth: int = None) -> list[Message]:
|
||||
"""
|
||||
Find all descendant messages to this message.
|
||||
|
||||
This function creates a subtree of messages starting from given root message.
|
||||
"""
|
||||
if isinstance(message, UUID):
|
||||
message = self.fetch_message(message)
|
||||
|
||||
desc = self.db.query(Message).filter(
|
||||
Message.message_tree_id == message.message_tree_id, Message.depth > message.depth
|
||||
)
|
||||
if max_depth is not None:
|
||||
desc = desc.filter(Message.depth <= max_depth)
|
||||
|
||||
desc = desc.all()
|
||||
|
||||
return self.trace_descendants(message, desc)
|
||||
|
||||
def fetch_longest_conversation(self, message: Message | UUID) -> list[Message]:
|
||||
tree = self.fetch_tree_from_message(message)
|
||||
max_message = max(tree, key=lambda m: m.depth)
|
||||
return self.trace_conversation(tree, max_message)
|
||||
|
||||
def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Message, list[Message]]:
|
||||
tree = self.fetch_tree_from_message(message)
|
||||
max_message = max(tree, key=lambda m: m.children_count)
|
||||
return max_message, [m for m in tree if m.parent_id == max_message.id]
|
||||
|
||||
def query_messages(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
username: Optional[str] = None,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
desc: bool = True,
|
||||
limit: Optional[int] = 10,
|
||||
start_date: Optional[datetime.datetime] = None,
|
||||
end_date: Optional[datetime.datetime] = None,
|
||||
only_roots: bool = False,
|
||||
deleted: Optional[bool] = None,
|
||||
) -> list[Message]:
|
||||
if not self.api_client.trusted and not api_client_id:
|
||||
# Let unprivileged api clients query their own messages without api_client_id being set
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if not self.api_client.trusted and api_client_id != self.api_client.id:
|
||||
# Unprivileged api client asks for foreign messages
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
messages = self.db.query(Message)
|
||||
if user_id:
|
||||
messages = messages.filter(Message.user_id == user_id)
|
||||
if username:
|
||||
messages = messages.join(User)
|
||||
messages = messages.filter(User.username == username)
|
||||
if api_client_id:
|
||||
messages = messages.filter(Message.api_client_id == api_client_id)
|
||||
|
||||
if start_date:
|
||||
messages = messages.filter(Message.created_date >= start_date)
|
||||
if end_date:
|
||||
messages = messages.filter(Message.created_date < end_date)
|
||||
|
||||
if only_roots:
|
||||
messages = messages.filter(Message.parent_id.is_(None))
|
||||
|
||||
if deleted is not None:
|
||||
messages = messages.filter(Message.deleted == deleted)
|
||||
|
||||
if desc:
|
||||
messages = messages.order_by(Message.created_date.desc())
|
||||
else:
|
||||
messages = messages.order_by(Message.created_date.asc())
|
||||
|
||||
if limit is not None:
|
||||
messages = messages.limit(limit)
|
||||
|
||||
# TODO: Pagination could be great at some point
|
||||
return messages.all()
|
||||
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
"""
|
||||
Marks deleted messages and all their descendants.
|
||||
"""
|
||||
if isinstance(messages, (Message, UUID)):
|
||||
messages = [messages]
|
||||
|
||||
ids = []
|
||||
for message in messages:
|
||||
if isinstance(message, UUID):
|
||||
ids.append(message)
|
||||
elif isinstance(message, Message):
|
||||
ids.append(message.id)
|
||||
else:
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
query = update(Message).where(Message.id.in_(ids)).values(deleted=True)
|
||||
self.db.execute(query)
|
||||
|
||||
parent_ids = ids
|
||||
if recursive:
|
||||
while parent_ids:
|
||||
query = (
|
||||
update(Message).filter(Message.parent_id.in_(parent_ids)).values(deleted=True).returning(Message.id)
|
||||
)
|
||||
|
||||
parent_ids = self.db.execute(query).scalars().all()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def get_stats(self) -> SystemStats:
|
||||
"""
|
||||
Get data stats such as number of all messages in the system,
|
||||
number of deleted and active messages and number of message trees.
|
||||
"""
|
||||
deleted = self.db.query(Message.deleted, func.count()).group_by(Message.deleted)
|
||||
nthreads = self.db.query(None, func.count(Message.id)).filter(Message.parent_id.is_(None))
|
||||
query = deleted.union_all(nthreads)
|
||||
result = {k: v for k, v in query.all()}
|
||||
|
||||
return SystemStats(
|
||||
all=result.get(True, 0) + result.get(False, 0),
|
||||
active=result.get(False, 0),
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
+6
-6
@@ -16,8 +16,8 @@ Setup requires a few steps:
|
||||
copilot app init --domain your_domain.com
|
||||
```
|
||||
|
||||
This will initialize and register a variety of URLs with your
|
||||
`your_domain.com`. Replace with a proper domain to setup SSL certificates.
|
||||
This will initialize and register a variety of URLs with your `your_domain.com`.
|
||||
Replace with a proper domain to setup SSL certificates.
|
||||
|
||||
```sh
|
||||
copilot env deploy
|
||||
@@ -29,10 +29,10 @@ This will create a variety of aws roles and services needed for deployment.
|
||||
copilot deploy
|
||||
```
|
||||
|
||||
This will depoy the services but it won't be 100% ready for usage. Before
|
||||
being ready, we have to inspect the AWS Secrets manager and extract out the
|
||||
database credentials. Read those credentials then put them, and a few other
|
||||
secrets, in a `secrets.yml` file like the following:
|
||||
This will depoy the services but it won't be 100% ready for usage. Before being
|
||||
ready, we have to inspect the AWS Secrets manager and extract out the database
|
||||
credentials. Read those credentials then put them, and a few other secrets, in a
|
||||
`secrets.yml` file like the following:
|
||||
|
||||
```yaml
|
||||
DATABASE_URL:
|
||||
|
||||
@@ -4,14 +4,17 @@ Parameters:
|
||||
Description: Your application's name.
|
||||
Env:
|
||||
Type: String
|
||||
Description: The environment name your service, job, or workflow is being deployed to.
|
||||
Description:
|
||||
The environment name your service, job, or workflow is being deployed to.
|
||||
Name:
|
||||
Type: String
|
||||
Description: The name of the service, job, or workflow being deployed.
|
||||
# Customize your Aurora Serverless cluster by setting the default value of the following parameters.
|
||||
webclusterDBName:
|
||||
Type: String
|
||||
Description: The name of the initial database to be created in the Aurora Serverless v2 cluster.
|
||||
Description:
|
||||
The name of the initial database to be created in the Aurora Serverless v2
|
||||
cluster.
|
||||
Default: oassist_web
|
||||
# Cannot have special characters
|
||||
# Naming constraints: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_Limits.html#RDS_Limits.Constraints
|
||||
@@ -29,15 +32,20 @@ Resources:
|
||||
webclusterDBSubnetGroup:
|
||||
Type: "AWS::RDS::DBSubnetGroup"
|
||||
Properties:
|
||||
DBSubnetGroupDescription: Group of Copilot private subnets for Aurora Serverless v2 cluster.
|
||||
DBSubnetGroupDescription:
|
||||
Group of Copilot private subnets for Aurora Serverless v2 cluster.
|
||||
SubnetIds:
|
||||
!Split [",", { "Fn::ImportValue": !Sub "${App}-${Env}-PrivateSubnets" }]
|
||||
webclusterSecurityGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A security group for your workload to access the Aurora Serverless v2 cluster webcluster"
|
||||
"aws:copilot:description":
|
||||
"A security group for your workload to access the Aurora Serverless v2
|
||||
cluster webcluster"
|
||||
Type: "AWS::EC2::SecurityGroup"
|
||||
Properties:
|
||||
GroupDescription: !Sub "The Security Group for ${Name} to access Aurora Serverless v2 cluster webcluster."
|
||||
GroupDescription:
|
||||
!Sub "The Security Group for ${Name} to access Aurora Serverless v2
|
||||
cluster webcluster."
|
||||
VpcId:
|
||||
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
|
||||
Tags:
|
||||
@@ -45,7 +53,8 @@ Resources:
|
||||
Value: !Sub "copilot-${App}-${Env}-${Name}-Aurora"
|
||||
webclusterDBClusterSecurityGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A security group for your Aurora Serverless v2 cluster webcluster"
|
||||
"aws:copilot:description":
|
||||
"A security group for your Aurora Serverless v2 cluster webcluster"
|
||||
Type: AWS::EC2::SecurityGroup
|
||||
Properties:
|
||||
GroupDescription: The Security Group for the Aurora Serverless v2 cluster.
|
||||
@@ -53,13 +62,15 @@ Resources:
|
||||
- ToPort: 5432
|
||||
FromPort: 5432
|
||||
IpProtocol: tcp
|
||||
Description: !Sub "From the Aurora Security Group of the workload ${Name}."
|
||||
Description:
|
||||
!Sub "From the Aurora Security Group of the workload ${Name}."
|
||||
SourceSecurityGroupId: !Ref webclusterSecurityGroup
|
||||
VpcId:
|
||||
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
|
||||
webclusterAuroraSecret:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A Secrets Manager secret to store your DB credentials"
|
||||
"aws:copilot:description":
|
||||
"A Secrets Manager secret to store your DB credentials"
|
||||
Type: AWS::SecretsManager::Secret
|
||||
Properties:
|
||||
Description: !Sub Aurora main user secret for ${AWS::StackName}
|
||||
@@ -71,7 +82,8 @@ Resources:
|
||||
PasswordLength: 16
|
||||
webclusterDBClusterParameterGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A DB parameter group for engine configuration values"
|
||||
"aws:copilot:description":
|
||||
"A DB parameter group for engine configuration values"
|
||||
Type: "AWS::RDS::DBClusterParameterGroup"
|
||||
Properties:
|
||||
Description: !Ref "AWS::StackName"
|
||||
@@ -80,7 +92,8 @@ Resources:
|
||||
client_encoding: "UTF8"
|
||||
webclusterDBCluster:
|
||||
Metadata:
|
||||
"aws:copilot:description": "The webcluster Aurora Serverless v2 database cluster"
|
||||
"aws:copilot:description":
|
||||
"The webcluster Aurora Serverless v2 database cluster"
|
||||
Type: "AWS::RDS::DBCluster"
|
||||
Properties:
|
||||
MasterUsername:
|
||||
@@ -117,7 +130,8 @@ Resources:
|
||||
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMaxCapacity]
|
||||
webclusterDBWriterInstance:
|
||||
Metadata:
|
||||
"aws:copilot:description": "The webcluster Aurora Serverless v2 writer instance"
|
||||
"aws:copilot:description":
|
||||
"The webcluster Aurora Serverless v2 writer instance"
|
||||
Type: "AWS::RDS::DBInstance"
|
||||
Properties:
|
||||
DBClusterIdentifier: !Ref webclusterDBCluster
|
||||
@@ -137,7 +151,10 @@ Resources:
|
||||
TargetType: AWS::RDS::DBCluster
|
||||
Outputs:
|
||||
webclusterSecret: # injected as WEBCLUSTER_SECRET environment variable by Copilot.
|
||||
Description: "The JSON secret that holds the database username and password. Fields are 'host', 'port', 'dbname', 'username', 'password', 'dbClusterIdentifier' and 'engine'"
|
||||
Description:
|
||||
"The JSON secret that holds the database username and password. Fields are
|
||||
'host', 'port', 'dbname', 'username', 'password', 'dbClusterIdentifier'
|
||||
and 'engine'"
|
||||
Value: !Ref webclusterAuroraSecret
|
||||
webclusterSecurityGroup:
|
||||
Description: "The security group to attach to the workload."
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
BOT_TOKEN=<discord bot token>
|
||||
DECLARE_GLOBAL_COMMANDS=<testing guild id>
|
||||
OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="./"
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
@@ -1,3 +1,10 @@
|
||||
.env
|
||||
*.egg-info/
|
||||
__pycache__/
|
||||
|
||||
.venv
|
||||
.nox
|
||||
.env
|
||||
|
||||
# Database files
|
||||
*.db
|
||||
|
||||
+129
-8
@@ -1,20 +1,141 @@
|
||||
# Open-Assistant Data Collection Discord Bot
|
||||
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large langugae model. You and other people can teach the bot how to respond to user requests by demonstration and by garding and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an
|
||||
assistant chat bot based on a large language model. You and other people can
|
||||
teach the bot how to respond to user requests by demonstration and by ranking
|
||||
the bot's outputs. If you want to learn more about RLHF please refer
|
||||
[to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
|
||||
## Invite official bot
|
||||
|
||||
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
|
||||
To add the official Open-Assistant data collection bot to your discord server
|
||||
[click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot).
|
||||
The bot needs access to read the contents of user text messages.
|
||||
|
||||
## Bot token for development
|
||||
## Contributing
|
||||
|
||||
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
|
||||
If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the
|
||||
[large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7)
|
||||
|
||||
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
|
||||
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
|
||||
### Setup
|
||||
|
||||
The simplest way to configure the token is via an `.env` file:
|
||||
To run the bot:
|
||||
|
||||
Install dependency module `oasst-shared`
|
||||
|
||||
```bash
|
||||
cd oasst-shared
|
||||
pip install -e .
|
||||
```
|
||||
BOT_TOKEN=XYZABC123...
|
||||
|
||||
```bash
|
||||
cd ../discord-bot
|
||||
cp .env.example .env
|
||||
|
||||
python -V # 3.10
|
||||
|
||||
pip install -r requirements.txt
|
||||
python -m bot
|
||||
```
|
||||
|
||||
Before you push, make sure the `pre-commit` hooks are installed and run
|
||||
successfully.
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
|
||||
...
|
||||
|
||||
git add .
|
||||
git commit -m "<good commit message>"
|
||||
# if the pre-commit fails
|
||||
git add .
|
||||
git commit -m "<good commit message>"
|
||||
```
|
||||
|
||||
To test the bot on your own discord server you need to register a discord
|
||||
application at the
|
||||
[Discord Developer Portal](https://discord.com/developers/applications) and get
|
||||
at bot token.
|
||||
|
||||
1. Follow a tutorial on how to get a bot token, for example this one:
|
||||
[Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
|
||||
2. The bot script expects the bot token to be in the `.env` file under the
|
||||
`TOKEN` variable.
|
||||
|
||||
### Resources
|
||||
|
||||
#### Structure
|
||||
|
||||
Important files
|
||||
|
||||
```graphql
|
||||
.env # Environment variables
|
||||
.env.example # Example environment variables
|
||||
CONTRIBUTING.md # This file
|
||||
README.md # Project readme
|
||||
EXAMPLES.md # Examples for commands and listeners
|
||||
requirements.txt # Requirements
|
||||
|
||||
bot/
|
||||
├─ __main__.py # Entrypoint
|
||||
├─ api_client.py # API Client for interacting with the backend
|
||||
├─ bot.py # Main bot class
|
||||
├─ settings.py # Settings and secrets
|
||||
├─ utils.py # Utility Functions
|
||||
│
|
||||
├─ db/ # Database related code
|
||||
│ ├─ database.db # SQLite database
|
||||
│ ├─ schema.sql # SQL schema
|
||||
│ └─ schemas.py # Python table schemas
|
||||
│
|
||||
└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
|
||||
├─ work.py # Task handling logic <-- most important file
|
||||
├─ guild_settings.py # Server specific settings
|
||||
└─ hot_reload.py # Utility for hot reload extensions during development
|
||||
```
|
||||
|
||||
#### Adding a new command/listener
|
||||
|
||||
1. Create a new file in the `extensions` folder
|
||||
2. Copy the template below
|
||||
|
||||
```py
|
||||
# -*- coding: utf-8 -*-
|
||||
"""My plugin."""
|
||||
import lightbulb
|
||||
|
||||
plugin = lightbulb.Plugin("MyPlugin")
|
||||
|
||||
# Add your commands here
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
```
|
||||
|
||||
#### Docs
|
||||
|
||||
Discord
|
||||
|
||||
- [Discord API Reference](https://discord.com/developers/docs/intro)
|
||||
|
||||
`hikari` (main framework)
|
||||
|
||||
- [Hikari Repo](https://github.com/hikari-py/hikari)
|
||||
- [Hikari Docs](https://docs.hikari-py.dev/en/latest/)
|
||||
|
||||
`lightbulb` (command handler)
|
||||
|
||||
- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb)
|
||||
- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/)
|
||||
|
||||
`miru` (component handler: buttons, modals, etc... )
|
||||
|
||||
- [Miru Repo](https://github.com/HyperGH/hikari-miru)
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from bot import OpenAssistantBot
|
||||
from bot_settings import settings
|
||||
|
||||
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = OpenAssistantBot(
|
||||
settings.BOT_TOKEN,
|
||||
bot_channel_name=settings.BOT_CHANNEL_NAME,
|
||||
backend_url=settings.BACKEND_URL,
|
||||
api_key=settings.API_KEY,
|
||||
owner_id=settings.OWNER_ID,
|
||||
template_dir=settings.TEMPLATE_DIR,
|
||||
debug=settings.DEBUG,
|
||||
)
|
||||
bot.run()
|
||||
@@ -1,79 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import requests
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
task_models_map: dict[str, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
self.task_models_map = task_models_map
|
||||
|
||||
def post(self, path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _parse_task(self, data: dict) -> protocol_schema.Task:
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("dict expected")
|
||||
|
||||
task_type = data.get("type")
|
||||
if task_type not in self.task_models_map:
|
||||
raise RuntimeError(f"Unsupported task type: {task_type}")
|
||||
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
def fetch_task(
|
||||
self,
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
) -> protocol_schema.Task:
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
|
||||
data = self.post("/api/v1/tasks/", req.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
|
||||
|
||||
def ack_task(self, task_id: str, post_id: str) -> None:
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
|
||||
|
||||
def nack_task(self, task_id: str, reason: str) -> None:
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
|
||||
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
data = self.post("/api/v1/tasks/interaction", interaction.dict())
|
||||
return self._parse_task(data)
|
||||
@@ -1,283 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import discord
|
||||
import task_handlers
|
||||
from api_client import ApiClient, TaskType
|
||||
from bot_base import BotBase
|
||||
from discord import app_commands
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import get_git_head_hash, utcnow
|
||||
|
||||
__version__ = "0.0.3"
|
||||
BOT_NAME = "Open-Assistant Junior"
|
||||
|
||||
|
||||
class OpenAssistantBot(BotBase):
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: str,
|
||||
bot_channel_name: str,
|
||||
backend_url: str,
|
||||
api_key: str,
|
||||
owner_id: Optional[Union[int, str]] = None,
|
||||
template_dir: str = "./templates",
|
||||
debug: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.template_dir = Path(template_dir)
|
||||
self.bot_channel_name = bot_channel_name
|
||||
self.templates = MessageTemplates(template_dir)
|
||||
self.debug = debug
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
if isinstance(owner_id, str):
|
||||
owner_id = int(owner_id)
|
||||
self.owner_id = owner_id
|
||||
|
||||
self.bot_token = bot_token
|
||||
client = discord.Client(intents=intents)
|
||||
self.client = client
|
||||
self.loop = client.loop
|
||||
|
||||
self.bot_channel: discord.TextChannel = None
|
||||
self.backend = ApiClient(backend_url, api_key)
|
||||
|
||||
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
|
||||
logger.info(f"{client.user} is now running!")
|
||||
|
||||
await self.delete_all_old_bot_messages()
|
||||
# if self.debug:
|
||||
# await self.post_boot_message()
|
||||
await self.post_welcome_message()
|
||||
|
||||
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message):
|
||||
# ignore own messages
|
||||
if message.author != client.user:
|
||||
await self.handle_message(message)
|
||||
|
||||
@self.tree.command()
|
||||
async def tutorial(interaction: discord.Interaction):
|
||||
"""Start the Open-Assistant tutorial via DMs."""
|
||||
|
||||
dm = await self.client.create_dm(discord.Object(interaction.user.id))
|
||||
await dm.send("Tutorial coming soon... :-)")
|
||||
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
|
||||
|
||||
@self.tree.command()
|
||||
async def help(interaction: discord.Interaction):
|
||||
"""Sends the user a list of all available commands"""
|
||||
await self.post_help(interaction.user)
|
||||
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
|
||||
|
||||
@self.tree.command()
|
||||
async def work(interaction: discord.Interaction):
|
||||
"""Request a new personalized task"""
|
||||
|
||||
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
|
||||
# task = self.backend.fetch_random_task(user=None)
|
||||
q = task_handlers.Questionnaire()
|
||||
await interaction.response.send_modal(q)
|
||||
|
||||
async def post_help(self, user: discord.abc.User) -> discord.Message:
|
||||
is_bot_owner = user.id == self.owner_id
|
||||
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
|
||||
|
||||
async def post_boot_message(self) -> discord.Message:
|
||||
return await self.post_template(
|
||||
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
|
||||
)
|
||||
|
||||
async def post_welcome_message(self) -> discord.Message:
|
||||
return await self.post_template("welcome.msg")
|
||||
|
||||
async def delete_all_old_bot_messages(self) -> None:
|
||||
logger.info("Deleting old threads...")
|
||||
for thread in self.bot_channel.threads:
|
||||
if thread.owner_id == self.client.user.id:
|
||||
await thread.delete()
|
||||
logger.info("Completed deleting old theards.")
|
||||
|
||||
logger.info("Deleting old messages...")
|
||||
look_until = utcnow() - timedelta(days=365)
|
||||
async for msg in self.bot_channel.history(limit=None):
|
||||
msg: discord.Message
|
||||
if msg.created_at < look_until:
|
||||
break
|
||||
if msg.author.id == self.client.user.id:
|
||||
await msg.delete()
|
||||
logger.info("Completed deleting old messages.")
|
||||
|
||||
async def next_task(self):
|
||||
task_type = protocol_schema.TaskRequestType.random
|
||||
task = self.backend.fetch_task(task_type, user=None)
|
||||
|
||||
handler: task_handlers.ChannelTaskBase = None
|
||||
match task.type:
|
||||
case TaskType.summarize_story:
|
||||
handler = task_handlers.SummarizeStoryHandler()
|
||||
case TaskType.rate_summary:
|
||||
handler = task_handlers.RateSummaryHandler()
|
||||
case TaskType.initial_prompt:
|
||||
handler = task_handlers.InitialPromptHandler()
|
||||
case TaskType.user_reply:
|
||||
handler = task_handlers.UserReplyHandler()
|
||||
case TaskType.assistant_reply:
|
||||
handler = task_handlers.AssistantReplyHandler()
|
||||
case TaskType.rank_initial_prompts:
|
||||
handler = task_handlers.RankInitialPromptsHandler()
|
||||
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
|
||||
handler = task_handlers.RankConversationsHandler()
|
||||
case _:
|
||||
logger.warning(f"Unsupported task type received: {task.type}")
|
||||
self.backend.nack_task(task.id, "not supported")
|
||||
|
||||
if handler:
|
||||
try:
|
||||
logger.info(f"strarting task {task.id}")
|
||||
msg = await handler.start(self, task)
|
||||
self.backend.ack_task(task.id, msg.id)
|
||||
except Exception:
|
||||
logger.exception("Starting task failed.")
|
||||
self.backend.nack_task(task.id, "faled")
|
||||
|
||||
async def background_timer(self):
|
||||
next_remove_completed = utcnow() + timedelta(seconds=10)
|
||||
next_fetch_task = utcnow() + timedelta(seconds=1)
|
||||
while True:
|
||||
now = utcnow()
|
||||
|
||||
if self.bot_channel:
|
||||
if now > next_fetch_task:
|
||||
next_fetch_task = utcnow() + timedelta(seconds=60)
|
||||
|
||||
try:
|
||||
await self.next_task()
|
||||
except Exception:
|
||||
logger.exception("fetching next task failed")
|
||||
|
||||
for x in self.reply_handlers.values():
|
||||
x.handler.tick(now)
|
||||
|
||||
if now > next_remove_completed:
|
||||
next_remove_completed = utcnow() + timedelta(seconds=10)
|
||||
await self.remove_completed_handlers()
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _sync(self, command: str, message: discord.Message):
|
||||
|
||||
logger.info(f"sync tree command received: {command}")
|
||||
|
||||
if command == "sync.copy_global":
|
||||
await self.tree.copy_global_to(guild=message.guild)
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
elif command == "sync.clear_guild":
|
||||
self.tree.clear_commands(guild=message.guild)
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
elif command == "sync.guild":
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
else:
|
||||
synced = await self.tree.sync()
|
||||
|
||||
logger.info(f"Synced {len(synced)} commands")
|
||||
await message.reply(f"Synced {len(synced)} commands")
|
||||
|
||||
async def handle_command(self, message: discord.Message, is_owner: bool):
|
||||
command_text: str = message.content
|
||||
command_text = command_text[1:]
|
||||
match command_text:
|
||||
case "help" | "?":
|
||||
await self.post_help(user=message.author)
|
||||
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
|
||||
if is_owner:
|
||||
await self._sync(command_text, message)
|
||||
case _:
|
||||
await message.reply(f"unknown command: {command_text}")
|
||||
|
||||
def recipient_filter(self, message: discord.Message) -> bool:
|
||||
channel = message.channel
|
||||
|
||||
if (
|
||||
message.channel.type == discord.ChannelType.private
|
||||
or message.channel.type == discord.ChannelType.private_thread
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
message.channel.type == discord.ChannelType.text
|
||||
or message.channel.type == discord.ChannelType.public_thread
|
||||
):
|
||||
while channel:
|
||||
if self.bot_channel and channel.id == self.bot_channel.id:
|
||||
return True
|
||||
channel = channel.parent
|
||||
|
||||
return False
|
||||
|
||||
async def handle_message(self, message: discord.Message):
|
||||
if not self.recipient_filter(message):
|
||||
return
|
||||
|
||||
user_id = message.author.id
|
||||
user_display_name = message.author.name
|
||||
|
||||
logger.debug(
|
||||
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
|
||||
)
|
||||
|
||||
command_prefix = "!"
|
||||
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
|
||||
is_owner = self.owner_id and user_id == self.owner_id
|
||||
await self.handle_command(message, is_owner)
|
||||
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
handler = self.reply_handlers.get(message.channel.id)
|
||||
if handler and not handler.handler.completed:
|
||||
handler.handler.on_reply(message)
|
||||
|
||||
if message.reference:
|
||||
handler = self.reply_handlers.get(message.reference.message_id)
|
||||
if handler and not handler.handler.completed:
|
||||
handler.handler.on_reply(message)
|
||||
|
||||
async def remove_completed_handlers(self):
|
||||
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
|
||||
if len(completed) == 0:
|
||||
return
|
||||
|
||||
for c in completed:
|
||||
handler = self.reply_handlers[c]
|
||||
del self.reply_handlers[c]
|
||||
try:
|
||||
await handler.handler.finalize()
|
||||
except Exception:
|
||||
logger.exception("handler finalize failed")
|
||||
|
||||
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
|
||||
|
||||
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
|
||||
for channel in self.client.get_all_channels():
|
||||
if channel.type == discord.ChannelType.text and channel.name == channel_name:
|
||||
return channel
|
||||
|
||||
def run(self):
|
||||
"""Run bot loop blocking."""
|
||||
self.client.run(self.bot_token)
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The official Open-Assistant Discord Bot."""
|
||||
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Entry point for the bot."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
from bot.bot import bot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if os.name != "nt":
|
||||
import uvloop
|
||||
|
||||
uvloop.install()
|
||||
|
||||
logger.info("Starting bot")
|
||||
bot.run()
|
||||
@@ -0,0 +1,114 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
import enum
|
||||
import typing as t
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
# TODO: Move to `protocol`?
|
||||
class TaskType(str, enum.Enum):
|
||||
"""Task types."""
|
||||
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class OasstApiClient:
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
"""Create a new OasstApiClient.
|
||||
|
||||
Args:
|
||||
----
|
||||
backend_url (str): The base backend URL.
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
self.task_models_map: dict[TaskType, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
|
||||
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
|
||||
def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
|
||||
if data is None:
|
||||
raise Exception("Cannot parse data as a task: data is none")
|
||||
task_type = TaskType(data.get("type"))
|
||||
|
||||
model = self.task_models_map.get(task_type)
|
||||
if not model:
|
||||
logger.error(f"Unsupported task type: {task_type}")
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
return self.task_models_map[task_type].parse_obj(data) # type: ignore
|
||||
|
||||
async def fetch_task(
|
||||
self,
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a task from the backend."""
|
||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
|
||||
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
||||
logger.debug(f"RESP {resp}")
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a random task from the backend."""
|
||||
logger.debug(f"Fetching random for user {user}")
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
|
||||
|
||||
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
|
||||
"""Send an ACK for a task to the backend."""
|
||||
logger.debug(f"ACK task {task_id} with post {message_id}")
|
||||
req = protocol_schema.TaskAck(message_id=message_id)
|
||||
await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
|
||||
|
||||
async def nack_task(self, task_id: str | UUID, reason: str) -> None:
|
||||
"""Send a NACK for a task to the backend."""
|
||||
logger.debug(f"NACK task {task_id} with reason {reason}")
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict())
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
"""Send a completed task to the backend."""
|
||||
logger.debug(f"Interaction: {interaction}")
|
||||
resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict())
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def close(self):
|
||||
logger.debug("Closing OasstApiClient session")
|
||||
await self.session.close()
|
||||
@@ -0,0 +1,117 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Bot logic."""
|
||||
from datetime import datetime
|
||||
|
||||
import aiosqlite
|
||||
import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from bot.api_client import OasstApiClient
|
||||
from bot.settings import Settings
|
||||
from bot.utils import EMPTY, mention
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# TODO: Revisit cache settings
|
||||
bot = lightbulb.BotApp(
|
||||
token=settings.bot_token,
|
||||
logs="DEBUG",
|
||||
prefix=settings.prefix,
|
||||
default_enabled_guilds=settings.declare_global_commands,
|
||||
owner_ids=settings.owner_ids,
|
||||
intents=hikari.Intents.ALL,
|
||||
)
|
||||
|
||||
|
||||
@bot.listen()
|
||||
async def on_starting(event: hikari.StartingEvent):
|
||||
"""Setup."""
|
||||
miru.install(bot) # component handler
|
||||
bot.load_extensions_from("./bot/extensions") # load extensions
|
||||
|
||||
bot.d.db = await aiosqlite.connect("./bot/db/database.db")
|
||||
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
|
||||
await bot.d.db.commit()
|
||||
|
||||
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
|
||||
|
||||
|
||||
@bot.listen()
|
||||
async def on_stopping(event: hikari.StoppingEvent):
|
||||
"""Cleanup."""
|
||||
await bot.d.db.close()
|
||||
await bot.d.oasst_api.close()
|
||||
|
||||
|
||||
async def _send_error_embed(
|
||||
content: str, exception: lightbulb.errors.LightbulbError | BaseException, ctx: lightbulb.Context
|
||||
) -> None:
|
||||
ctx.command
|
||||
embed = hikari.Embed(
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `{ctx.command.name}`' if ctx.command else '' }",
|
||||
description=content,
|
||||
color=0xFF0000,
|
||||
timestamp=datetime.now().astimezone(),
|
||||
).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url))
|
||||
|
||||
await ctx.respond(EMPTY, embed=embed)
|
||||
|
||||
|
||||
@bot.listen(lightbulb.CommandErrorEvent)
|
||||
async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
"""Error handler for the bot."""
|
||||
# Unwrap the exception to get the original cause
|
||||
exc = event.exception.__cause__ or event.exception
|
||||
ctx = event.context
|
||||
|
||||
if isinstance(event.exception, lightbulb.CommandInvocationError):
|
||||
if not event.context.command:
|
||||
await _send_error_embed("Something went wrong", exc, ctx)
|
||||
else:
|
||||
await _send_error_embed(
|
||||
f"Something went wrong during invocation of command `{event.context.command.name}`.", exc, ctx
|
||||
)
|
||||
|
||||
raise event.exception
|
||||
|
||||
# Not an owner
|
||||
if isinstance(exc, lightbulb.NotOwner):
|
||||
await _send_error_embed("You are not the owner of this bot.", exc, ctx)
|
||||
# Command is on cooldown
|
||||
elif isinstance(exc, lightbulb.CommandIsOnCooldown):
|
||||
await _send_error_embed(f"This command is on cooldown. Retry in `{exc.retry_after:.2f}` seconds.", exc, ctx)
|
||||
# Missing permissions
|
||||
elif isinstance(exc, lightbulb.errors.MissingRequiredPermission):
|
||||
await _send_error_embed(
|
||||
f"You do not have permission to use this command. Missing permissions: {exc.missing_perms}", exc, ctx
|
||||
)
|
||||
# Missing roles
|
||||
elif isinstance(exc, lightbulb.errors.MissingRequiredRole):
|
||||
assert event.context.guild_id is not None # Roles only exist in guilds
|
||||
await _send_error_embed(
|
||||
f"You do not have the correct role to use this command. Missing role(s): {[mention(r, 'role') for r in exc.missing_roles]}",
|
||||
exc,
|
||||
ctx,
|
||||
)
|
||||
# Only a guild command
|
||||
elif isinstance(exc, lightbulb.errors.OnlyInGuild):
|
||||
await _send_error_embed("This command can only be run in servers.", exc, ctx)
|
||||
# Only a DM command
|
||||
elif isinstance(exc, lightbulb.errors.OnlyInDM):
|
||||
await _send_error_embed("This command can only be run in DMs.", exc, ctx)
|
||||
# Not enough arguments
|
||||
elif isinstance(exc, lightbulb.errors.NotEnoughArguments):
|
||||
await _send_error_embed(
|
||||
f"Not enough arguments were supplied to the command. {[opt.name for opt in exc.missing_options]}", exc, ctx
|
||||
)
|
||||
# Bot missing permission
|
||||
elif isinstance(exc, lightbulb.errors.BotMissingRequiredPermission):
|
||||
await _send_error_embed(
|
||||
f"The bot does not have the correct permission(s) to execute this command. Missing permissions: {exc.missing_perms}",
|
||||
exc,
|
||||
ctx,
|
||||
)
|
||||
elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment):
|
||||
await _send_error_embed("Not enough attachemnts were supplied to this command.", exc, ctx)
|
||||
else:
|
||||
raise exc
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Sqlite3 schema for the bot
|
||||
CREATE TABLE IF NOT EXISTS guild_settings (
|
||||
guild_id BIGINT NOT NULL PRIMARY KEY,
|
||||
log_channel_id BIGINT
|
||||
);
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Database schemas."""
|
||||
import typing as t
|
||||
|
||||
from aiosqlite import Connection, Row
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GuildSettings(BaseModel):
|
||||
"""Guild settings."""
|
||||
|
||||
guild_id: int
|
||||
log_channel_id: int | None
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls, obj: Row) -> "GuildSettings":
|
||||
"""Deserialize a Row object from aiosqlite into a GuildSettings object."""
|
||||
return cls(guild_id=obj[0], log_channel_id=obj[1])
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, conn: Connection, guild_id: int) -> t.Optional["GuildSettings"]:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return cls.parse_obj(row)
|
||||
@@ -0,0 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Extensions for the bot.
|
||||
|
||||
See: https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
|
||||
"""
|
||||
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Guild settings."""
|
||||
import hikari
|
||||
import lightbulb
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import mention
|
||||
from lightbulb.utils import permissions_in
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin("GuildSettings")
|
||||
plugin.add_checks(lightbulb.guild_only)
|
||||
plugin.add_checks(lightbulb.has_guild_permissions(hikari.Permissions.MANAGE_GUILD))
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("settings", "Bot settings for the server.")
|
||||
@lightbulb.implements(lightbulb.SlashCommandGroup)
|
||||
async def settings(_: lightbulb.SlashContext) -> None:
|
||||
"""Bot settings for the server."""
|
||||
# This will never execute because it is a group
|
||||
pass
|
||||
|
||||
|
||||
@settings.child
|
||||
@lightbulb.command("get", "Get all the guild settings.")
|
||||
@lightbulb.implements(lightbulb.SlashSubCommand)
|
||||
async def get(ctx: lightbulb.SlashContext) -> None:
|
||||
"""Get one of or all the guild settings."""
|
||||
conn: Connection = ctx.bot.d.db
|
||||
assert ctx.guild_id is not None # `guild_only` check
|
||||
|
||||
async with conn.cursor() as cursor:
|
||||
# Get all settings
|
||||
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (ctx.guild_id,))
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
logger.warning(f"No guild settings for {ctx.guild_id}")
|
||||
await ctx.respond("No settings found for this guild.")
|
||||
return
|
||||
|
||||
guild_settings = GuildSettings.parse_obj(row)
|
||||
|
||||
# Respond with all
|
||||
# TODO: Embed
|
||||
await ctx.respond(
|
||||
f"""\
|
||||
**Guild Settings**
|
||||
`log_channel`: {
|
||||
mention(guild_settings.log_channel_id, "channel")
|
||||
if guild_settings.log_channel_id else 'not set'}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@settings.child
|
||||
@lightbulb.option("channel", "The channel to use.", hikari.TextableGuildChannel)
|
||||
@lightbulb.command("log_channel", "Set the channel that the bot logs task and label completions in.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashSubCommand)
|
||||
async def log_channel(ctx: lightbulb.SlashContext) -> None:
|
||||
"""Set the channel that the bot logs task and label completions in."""
|
||||
channel: hikari.TextableGuildChannel = ctx.options.channel
|
||||
conn: Connection = ctx.bot.d.db
|
||||
assert ctx.guild_id is not None # `guild_only` check
|
||||
|
||||
# Check if the bot can send messages in that channel
|
||||
assert isinstance(channel, hikari.InteractionChannel) # Slash commands are interactions
|
||||
me = ctx.bot.cache.get_me() or await ctx.bot.rest.fetch_my_user()
|
||||
own_member = ctx.bot.cache.get_member(ctx.guild_id, me.id) or await ctx.bot.rest.fetch_member(ctx.guild_id, me.id)
|
||||
|
||||
# Get the channel from the cache if it is there, otherwise fetch it
|
||||
if (ch := ctx.bot.cache.get_guild_channel(channel.id)) is None:
|
||||
ch = {ch.id: ch for ch in await ctx.bot.rest.fetch_guild_channels(channel.id)}[channel.id]
|
||||
|
||||
if not isinstance(ch, hikari.GuildTextChannel):
|
||||
await ctx.respond(f"{ch.mention} is not a text channel.")
|
||||
return
|
||||
|
||||
# if the bot's permissions for this channel don't contain SEND_MESSAGE
|
||||
# This will also filter out categories and voice channels
|
||||
print(permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES)
|
||||
if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES:
|
||||
await ctx.respond(f"I don't have permission to send messages in {ch.mention}.")
|
||||
return
|
||||
|
||||
await ctx.respond(f"Setting `log_channel` to {channel.mention}.")
|
||||
|
||||
# update the database
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT OR REPLACE INTO guild_settings (guild_id, log_channel_id) VALUES (?, ?)",
|
||||
(ctx.guild_id, channel.id),
|
||||
)
|
||||
await conn.commit()
|
||||
logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.")
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hot reload plugin."""
|
||||
from glob import glob
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
"HotReloadPlugin",
|
||||
)
|
||||
plugin.add_checks(lightbulb.owner_only)
|
||||
|
||||
EXTENSIONS_FOLDER = "bot/extensions"
|
||||
|
||||
|
||||
def _get_extensions() -> list[str]:
|
||||
# Recursively get all the .py files in the extensions directory not starting with an `_`.
|
||||
exts = glob("bot/extensions/**/[!_]*.py", recursive=True)
|
||||
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
|
||||
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
|
||||
|
||||
|
||||
async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]:
|
||||
# Check that the option is a string.
|
||||
if not isinstance(option.value, str):
|
||||
raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`")
|
||||
|
||||
exts = _get_extensions()
|
||||
return [ext for ext in exts if option.value in ext]
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"plugin",
|
||||
"The plugin to reload. Leave empty to reload all plugins.",
|
||||
autocomplete=_plugin_autocomplete,
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
@lightbulb.command("reload", "Reload a plugin", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def reload(ctx: lightbulb.SlashContext):
|
||||
"""Reload a plugin or all plugins."""
|
||||
# If the plugin option is None, reload all plugins.
|
||||
if ctx.options.plugin is None:
|
||||
ctx.bot.reload_extensions(*_get_extensions())
|
||||
await ctx.respond("Reloaded all plugins.")
|
||||
logger.info("Reloaded all plugins.")
|
||||
# Otherwise, reload the specified plugin.
|
||||
else:
|
||||
ctx.bot.reload_extensions(ctx.options.plugin)
|
||||
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
|
||||
logger.info(f"Reloaded `{ctx.options.plugin}`.")
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,181 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hot reload plugin."""
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
"TextLabels",
|
||||
)
|
||||
plugin.add_checks(lightbulb.guild_only) # Context menus are only enabled in guilds
|
||||
|
||||
|
||||
DISCORD_GRAY = 0x2F3136
|
||||
|
||||
|
||||
def clamp(num: float) -> float:
|
||||
"""Clamp a number between 0 and 1."""
|
||||
return min(max(0.0, num), 1.0)
|
||||
|
||||
|
||||
class LabelModal(miru.Modal):
|
||||
"""Modal for submitting text labels."""
|
||||
|
||||
def __init__(self, label: str, content: str, *args: t.Any, **kwargs: t.Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.label = label
|
||||
self.original_content = content
|
||||
|
||||
# Add the text of the message to the modal
|
||||
self.content = miru.TextInput(
|
||||
label="Text", style=hikari.TextInputStyle.PARAGRAPH, value=content, required=True, row=1
|
||||
)
|
||||
self.add_item(self.content)
|
||||
|
||||
value = miru.TextInput(label="Value", placeholder="Enter a value between 0 and 1", required=True, row=2)
|
||||
|
||||
async def callback(self, context: miru.ModalContext) -> None:
|
||||
val = float(self.value.value) if self.value.value else 0.0
|
||||
val = clamp(val)
|
||||
|
||||
edited = self.content.value != self.original_content
|
||||
await context.respond(
|
||||
f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.",
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
)
|
||||
logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.")
|
||||
|
||||
# Send a notification to the log channel
|
||||
assert context.guild_id is not None # `guild_only` check
|
||||
conn: Connection = context.bot.d.db # type: ignore
|
||||
guild_settings = await GuildSettings.from_db(conn, context.guild_id)
|
||||
|
||||
if guild_settings is None or guild_settings.log_channel_id is None:
|
||||
logger.warning(f"No guild settings or log channel for guild {context.guild_id}")
|
||||
return
|
||||
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Message Label",
|
||||
description=f"{context.author.mention} labeled a message as `{self.label}`.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
color=0x00FF00,
|
||||
)
|
||||
.set_author(name=context.author.username, icon=context.author.avatar_url)
|
||||
.add_field("Total Labeled Message", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
)
|
||||
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel)
|
||||
await channel.send(EMPTY, embed=embed)
|
||||
|
||||
|
||||
class LabelSelect(miru.View):
|
||||
"""Select menu for selecting a label.
|
||||
|
||||
The current labels are:
|
||||
- contains toxic language
|
||||
- encourages illegal activity
|
||||
- good quality
|
||||
- bad quality
|
||||
- is spam
|
||||
"""
|
||||
|
||||
def __init__(self, content: str, *args: t.Any, **kwargs: t.Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.content = content
|
||||
|
||||
@miru.select(
|
||||
options=[
|
||||
hikari.SelectMenuOption(
|
||||
label="Toxic Language",
|
||||
value="toxic_language",
|
||||
description="The message contains toxic language.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Illegal Activity",
|
||||
value="illegal_activity",
|
||||
description="The message encourages illegal activity.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Good Quality",
|
||||
value="good_quality",
|
||||
description="The message is good quality.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Bad Quality",
|
||||
value="bad_quality",
|
||||
description="The message is bad quality.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Spam",
|
||||
value="spam",
|
||||
description="The message is spam.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
],
|
||||
min_values=1,
|
||||
max_values=1,
|
||||
)
|
||||
async def label_select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
|
||||
"""Handle the select menu."""
|
||||
label = select.values[0]
|
||||
modal = LabelModal(label, self.content, title=f"Text Label: {label}", timeout=60)
|
||||
await modal.send(ctx.interaction)
|
||||
await modal.wait()
|
||||
|
||||
self.stop()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("Label Message", "Label a message")
|
||||
@lightbulb.implements(lightbulb.MessageCommand)
|
||||
async def label_message_text(ctx: lightbulb.MessageContext):
|
||||
"""Label a message."""
|
||||
# We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction
|
||||
# so the select menu will open the modal
|
||||
|
||||
msg: hikari.Message = ctx.options.target
|
||||
# Exit if the message is empty
|
||||
if not msg.content:
|
||||
await ctx.respond("Cannot label an empty message.", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
return
|
||||
|
||||
# Send the select menu
|
||||
# The modal will be opened from the select menu interaction
|
||||
embed = hikari.Embed(title="Label Message", description="Select a label for the message.", color=DISCORD_GRAY)
|
||||
label_select_view = LabelSelect(
|
||||
msg.content,
|
||||
timeout=60,
|
||||
)
|
||||
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
await label_select_view.start(await resp.message())
|
||||
await label_select_view.wait()
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task plugin for testing different data collection methods."""
|
||||
# TODO: Delete this once user input method has been decided for final bot.
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from bot.utils import format_time
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("TaskPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60
|
||||
MAX_TASK_ACCEPT_TIME = 60
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def task_thread(ctx: lightbulb.SlashContext):
|
||||
"""Request a task from the backend."""
|
||||
typ: str = ctx.options.type
|
||||
|
||||
# Create a thread for the task
|
||||
thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}")
|
||||
|
||||
await ctx.respond(f"Please complete the task in the thread: {thread.mention}")
|
||||
|
||||
# Send the task in the thread
|
||||
await thread.send(
|
||||
f"""\
|
||||
Please complete the task.
|
||||
Sample Task
|
||||
|
||||
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
|
||||
"""
|
||||
)
|
||||
|
||||
# Wait for the user to respond
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.GuildMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id,
|
||||
)
|
||||
await ctx.respond(f"Received message: {event.message.content}")
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("You took too long to respond.")
|
||||
finally:
|
||||
await thread.delete()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def task_dm(ctx: lightbulb.Context):
|
||||
"""Request a task from the backend."""
|
||||
await ctx.respond("Please complete the task in your DMs")
|
||||
|
||||
# Send the task in the dm
|
||||
await ctx.author.send(
|
||||
f"""\
|
||||
Please complete the task.
|
||||
Sample Task
|
||||
|
||||
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
|
||||
"""
|
||||
)
|
||||
|
||||
# Wait for the user to respond
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id,
|
||||
)
|
||||
await ctx.respond(f"Received message: {event.message.content}")
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("You took too long to respond.")
|
||||
|
||||
|
||||
class TaskModal(miru.Modal):
|
||||
"""Modal for submitting a task."""
|
||||
|
||||
response = miru.TextInput(
|
||||
label="Response",
|
||||
placeholder="Enter your response!",
|
||||
required=True,
|
||||
style=hikari.TextInputStyle.PARAGRAPH,
|
||||
row=2,
|
||||
)
|
||||
|
||||
async def callback(self, context: miru.ModalContext) -> None:
|
||||
await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
class ModalView(miru.View):
|
||||
"""View for opening a modal."""
|
||||
|
||||
def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.modal_title = modal_title
|
||||
self.task = task
|
||||
|
||||
@miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
modal = TaskModal(title=self.modal_title)
|
||||
modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1))
|
||||
await ctx.respond_with_modal(modal)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def task_modal(ctx: lightbulb.SlashContext):
|
||||
"""Request a task from the backend."""
|
||||
# typ: str = ctx.options.type
|
||||
view = ModalView(
|
||||
modal_title="Assistant Response",
|
||||
task="Please explain the moon landing to a six year old.",
|
||||
timeout=MAX_TASK_TIME,
|
||||
)
|
||||
resp = await ctx.respond(
|
||||
"Task - Respond to the prompt as if you were the Assistant:",
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
components=view,
|
||||
)
|
||||
await view.start(await resp.message())
|
||||
|
||||
|
||||
class RatingView(miru.View):
|
||||
"""View for rating a task."""
|
||||
|
||||
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.presses: list[str] = []
|
||||
|
||||
def _close_if_all_pressed(self) -> None:
|
||||
if len(self.presses) == 5:
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="1", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("1")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="2", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("2")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="3", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("3")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="4", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("4")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="5", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("5")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="Reset", style=hikari.ButtonStyle.DANGER)
|
||||
async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.presses = []
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
class SelectRating(miru.View):
|
||||
"""View for rating a task with a select menu."""
|
||||
|
||||
@miru.select(
|
||||
options=[
|
||||
hikari.SelectMenuOption(
|
||||
label="1",
|
||||
value="1",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="2",
|
||||
value="2",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="3",
|
||||
value="3",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
],
|
||||
placeholder="Select the good responses",
|
||||
min_values=0,
|
||||
max_values=3,
|
||||
row=3,
|
||||
)
|
||||
async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
|
||||
await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("rating_task", "Rate stuff.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def rating_task(ctx: lightbulb.SlashContext):
|
||||
"""Rate stuff."""
|
||||
# Message Based rating
|
||||
await ctx.respond(
|
||||
"List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL
|
||||
)
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("Timed out waiting for response")
|
||||
return
|
||||
|
||||
if event.content is None:
|
||||
await ctx.respond("No content in message")
|
||||
return
|
||||
ratings = event.content.replace(" ", "").split(",")
|
||||
|
||||
# Check if the ratings are valid
|
||||
if len(ratings) != 5:
|
||||
await ctx.respond("Invalid number of ratings")
|
||||
if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]):
|
||||
await ctx.respond("Invalid rating")
|
||||
|
||||
await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
# Button Based rating
|
||||
view = RatingView(timeout=MAX_TASK_TIME)
|
||||
|
||||
resp = await ctx.respond("Click the buttons in order of best to worst response", components=view)
|
||||
await view.start(await resp.message())
|
||||
await view.wait()
|
||||
await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
await resp.delete()
|
||||
|
||||
# Select Based rating
|
||||
select_view = SelectRating(timeout=MAX_TASK_TIME)
|
||||
resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
await select_view.start(await resp_2.message())
|
||||
await select_view.wait()
|
||||
await resp_2.delete()
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,451 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.api_client import OasstApiClient, TaskType
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("WorkPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=str(TaskRequestType.random),
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def work(ctx: lightbulb.SlashContext):
|
||||
"""Create and handle a task."""
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
|
||||
await _handle_task(ctx, task_type)
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
If they select one, present the task steps until a `task_done` task is received.
|
||||
Finally, ask the user if they want to perform another task (of the same type).
|
||||
"""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
|
||||
# Continue to complete tasks until the user doesn't want to do another
|
||||
done = False
|
||||
while not done:
|
||||
|
||||
# Loop until the user accepts a task
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send("Please type your response here:")
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None or not _validate_user_input(event.content, task):
|
||||
await ctx.author.send("Invalid response")
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
new_task = await oasst_api.post_interaction(reply)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send("Task completed")
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in the log channel that the task is complete
|
||||
# TODO: Maybe do something with the msg ID so users can rate the "answer"
|
||||
assert ctx.guild_id is not None
|
||||
conn: Connection = ctx.bot.d.db
|
||||
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
|
||||
|
||||
if guild_settings is not None and guild_settings.log_channel_id is not None:
|
||||
|
||||
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel) # option converter
|
||||
|
||||
done_embed = (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {ctx.author.mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
await channel.send(EMPTY, embed=done_embed)
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send("Would you like another task?", components=choice_view)
|
||||
await choice_view.start(msg)
|
||||
await choice_view.wait()
|
||||
|
||||
match choice_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await ctx.author.send("Exiting, goodbye!")
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
case "accept":
|
||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
||||
await oasst_api.ack_task(task.id, msg_id)
|
||||
return task, msg_id
|
||||
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
await ctx.author.send("Sending next task...")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send("Task canceled. Exiting")
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.SlashContext, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
"""
|
||||
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_prompter_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
embed = _prompter_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
embed = _assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"unknown task type {task.type}")
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
EMPTY,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, str(msg.id)
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
|
||||
"""Returns whether the user's input is valid for the task type."""
|
||||
if content is None:
|
||||
return False
|
||||
|
||||
# User message input
|
||||
if (
|
||||
task.type == TaskRequestType.initial_prompt
|
||||
or task.type == TaskRequestType.prompter_reply
|
||||
or task.type == TaskRequestType.assistant_reply
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"Unknown task type {task.type}")
|
||||
raise ValueError(f"Unknown task type {task.type}")
|
||||
|
||||
|
||||
class TaskAcceptView(miru.View):
|
||||
"""View with three buttons: accept, next, and cancel.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
"""
|
||||
|
||||
choice: t.Literal["accept", "next", "cancel"] | None = None
|
||||
|
||||
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
|
||||
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Accept button pressed")
|
||||
self.choice = "accept"
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
||||
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Next button pressed")
|
||||
self.choice = "next"
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
||||
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Cancel button pressed")
|
||||
self.choice = "cancel"
|
||||
self.stop()
|
||||
|
||||
|
||||
class ChoiceView(miru.View):
|
||||
"""View with two buttons: yes and no.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
"""
|
||||
|
||||
choice: bool | None = None
|
||||
|
||||
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
||||
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = True
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
||||
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = False
|
||||
self.stop()
|
||||
|
||||
|
||||
################################################################
|
||||
# Template Embeds #
|
||||
################################################################
|
||||
|
||||
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
|
||||
|
||||
|
||||
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Initial Prompt",
|
||||
description="Rank the following tasks from best to worst (1,2,3,4,5)",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(task.prompts):
|
||||
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank User Reply",
|
||||
description="Rank the following user replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Assistant Reply",
|
||||
description="Rank the following assistant replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description=f"""\
|
||||
Send the next message in the conversation as if you were the user.
|
||||
{'Hint: ' if task.hint else ''}
|
||||
""",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description="Send the next message in the conversation as if you were the user.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Configuration for the bot."""
|
||||
from pydantic import BaseSettings, Field
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings for the bot."""
|
||||
|
||||
bot_token: str = Field(env="BOT_TOKEN", default="")
|
||||
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
|
||||
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
|
||||
prefix: str = Field(env="PREFIX", default="./")
|
||||
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
|
||||
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
|
||||
|
||||
class Config(BaseSettings.Config):
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Utility functions."""
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
|
||||
|
||||
def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> str:
|
||||
"""Format a datetime object into the discord time format.
|
||||
|
||||
```
|
||||
| t | HH:MM | 16:20
|
||||
| T | HH:MM:SS | 16:20:11
|
||||
| D | D Mo Yr | 20 April 2022
|
||||
| f | D Mo Yr HH:MM | 20 April 2022 16:20
|
||||
| F | W, D Mo Yr HH:MM | Wednesday, 20 April 2022 16:20
|
||||
| R | relative | in an hour
|
||||
```
|
||||
"""
|
||||
match fmt:
|
||||
case "t" | "T" | "D" | "f" | "F" | "R":
|
||||
return f"<t:{dt.timestamp():.0f}:{fmt}>"
|
||||
case _:
|
||||
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
|
||||
|
||||
|
||||
EMPTY = "\u200d"
|
||||
"""Zero-width joiner.
|
||||
|
||||
This appears as an empty message in Discord.
|
||||
"""
|
||||
|
||||
|
||||
def mention(
|
||||
id: hikari.Snowflakeish,
|
||||
type: t.Literal["channel", "role", "user"],
|
||||
) -> str:
|
||||
"""Mention an object."""
|
||||
match type:
|
||||
case "channel":
|
||||
return f"<#{id}>"
|
||||
|
||||
case "user":
|
||||
return f"<@{id}>"
|
||||
|
||||
case "role":
|
||||
return f"<@&{id}>"
|
||||
@@ -1,61 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from channel_handlers import ChannelHandlerBase
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplyHandlerInfo:
|
||||
msg_id: int
|
||||
handler_task: asyncio.Task
|
||||
handler: ChannelHandlerBase
|
||||
|
||||
|
||||
class BotBase(ABC):
|
||||
bot_channel_name: str
|
||||
debug: bool
|
||||
backend: ApiClient
|
||||
client: discord.Client
|
||||
loop: asyncio.BaseEventLoop
|
||||
owner_id: int
|
||||
bot_channel: discord.TextChannel
|
||||
templates: MessageTemplates
|
||||
reply_handlers: dict[int, ReplyHandlerInfo]
|
||||
|
||||
def __init__(self):
|
||||
self.reply_handlers = {} # handlers by msg_id
|
||||
|
||||
def ensure_bot_channel(self) -> None:
|
||||
if self.bot_channel is None:
|
||||
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
|
||||
|
||||
async def post(
|
||||
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
|
||||
) -> discord.Message:
|
||||
if channel is None:
|
||||
self.ensure_bot_channel()
|
||||
channel = self.bot_channel
|
||||
return await channel.send(content=content, view=view)
|
||||
|
||||
async def post_template(
|
||||
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
|
||||
) -> discord.Message:
|
||||
logger.debug(f"rendering {name}")
|
||||
text = self.templates.render(name, **kwargs)
|
||||
return await self.post(text, view=view, channel=channel)
|
||||
|
||||
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
|
||||
if msg_id in self.reply_handlers:
|
||||
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
|
||||
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
|
||||
task.add_done_callback(lambda t: handler.on_completed())
|
||||
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
|
||||
@@ -1,15 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import AnyHttpUrl, BaseSettings
|
||||
|
||||
|
||||
class BotSettings(BaseSettings):
|
||||
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
|
||||
API_KEY: str = "any_key"
|
||||
BOT_TOKEN: str
|
||||
BOT_CHANNEL_NAME: str = "bot"
|
||||
OWNER_ID: int = None
|
||||
TEMPLATE_DIR: str = "./templates"
|
||||
DEBUG: bool = True
|
||||
|
||||
|
||||
settings = BotSettings(_env_file=".env")
|
||||
@@ -1,88 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
import discord
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ChannelExpiredException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ChannelHandlerBase(ABC):
|
||||
queue: asyncio.Queue
|
||||
completed: bool = False
|
||||
expiry_date: datetime
|
||||
expired: bool = False
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
self.expiry_date = expiry_date
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
"""Call this method to read the next message from the user in the handler method."""
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
|
||||
msg = await self.queue.get()
|
||||
if msg is None:
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
else:
|
||||
raise RuntimeError("Unexpected None message read")
|
||||
return msg
|
||||
|
||||
def on_reply(self, message: discord.Message) -> None:
|
||||
self.queue.put_nowait(message)
|
||||
|
||||
def on_expire(self) -> None:
|
||||
logger.info("ChannelHandler: on_expire")
|
||||
self.expired = True
|
||||
self.queue.put_nowait(None)
|
||||
|
||||
def on_completed(self) -> None:
|
||||
logger.info("ChannelHandler: on_completed")
|
||||
self.completed = True
|
||||
|
||||
def tick(self, now: datetime):
|
||||
if now > self.expiry_date and not self.expired:
|
||||
self.on_expire()
|
||||
|
||||
@abstractmethod
|
||||
async def handler_loop(self):
|
||||
...
|
||||
|
||||
async def finalize(self):
|
||||
pass
|
||||
|
||||
|
||||
class AutoDestructThreadHandler(ChannelHandlerBase):
|
||||
first_message: discord.Message = None
|
||||
thread: discord.Thread = None
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
super().__init__(expiry_date=expiry_date)
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
try:
|
||||
return await super().read()
|
||||
except ChannelExpiredException:
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def cleanup(self):
|
||||
logger.debug("AutoDestructThreadHandler.cleanup")
|
||||
if self.thread:
|
||||
logger.debug(f"deleting thread: {self.thread.name}")
|
||||
await self.thread.delete()
|
||||
self.thread = None
|
||||
if self.first_message:
|
||||
logger.debug(f"deleting first_message: {self.first_message.content}")
|
||||
await self.first_message.delete()
|
||||
self.first_message = None
|
||||
|
||||
async def finalize(self):
|
||||
await self.cleanup()
|
||||
return await super().finalize()
|
||||
@@ -1,16 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Message templates for the discord bot."""
|
||||
import typing
|
||||
|
||||
import jinja2
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageTemplates:
|
||||
def __init__(self, template_dir="./templates"):
|
||||
self.env = jinja2.Environment(
|
||||
"""Create message templates for the discord bot."""
|
||||
|
||||
def __init__(self, template_dir: str = "./templates"):
|
||||
self.env = jinja2.Environment( # noqa: S701
|
||||
loader=jinja2.FileSystemLoader(template_dir),
|
||||
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
|
||||
)
|
||||
|
||||
def render(self, template_name, **kwargs):
|
||||
def render(self, template_name: str, **kwargs: typing.Any):
|
||||
template = self.env.get_template(template_name)
|
||||
txt = template.render(kwargs)
|
||||
logger.debug(txt)
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
pytest
|
||||
pytest-asyncio
|
||||
@@ -1,7 +1,11 @@
|
||||
discord.py==2.1.0
|
||||
Jinja2==3.1.2
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
pytz==2022.7
|
||||
requests==2.28.1
|
||||
schedule==1.1.0
|
||||
aiohttp # http client
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
aiosqlite # database
|
||||
hikari # discord framework
|
||||
hikari-lightbulb # command handler
|
||||
hikari-miru # modals and buttons
|
||||
hikari[speedups]
|
||||
loguru
|
||||
pydantic
|
||||
|
||||
uvloop; os_name != 'nt' # Faster drop-in replacement for asyncio event loop
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from bot_base import BotBase
|
||||
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
|
||||
|
||||
|
||||
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
|
||||
name = discord.ui.TextInput(label="Name")
|
||||
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
|
||||
|
||||
|
||||
class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
thread_name: str = "Replies"
|
||||
expires_after: timedelta = timedelta(minutes=5)
|
||||
backend: ApiClient
|
||||
|
||||
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
|
||||
try:
|
||||
self.bot = bot
|
||||
self.task = task
|
||||
self.backend = bot.backend
|
||||
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
|
||||
msg = await self.send_first_message()
|
||||
self.first_message = msg
|
||||
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
|
||||
await self.on_thread_created(self.thread)
|
||||
except Exception:
|
||||
logger.exception("start task failed")
|
||||
await self.cleanup() # try to cleanup messag or thread
|
||||
raise
|
||||
|
||||
bot.register_reply_handler(msg_id=msg.id, handler=self)
|
||||
return msg
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_first_message(self) -> discord.message:
|
||||
...
|
||||
|
||||
def to_api_user(self, user: discord.User) -> protocol_schema.User:
|
||||
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
|
||||
|
||||
async def post_teaser_msg(self, template_name: str):
|
||||
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
|
||||
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
|
||||
return await self.bot.post_template(
|
||||
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
|
||||
)
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
api_response = await self.backend.post_interaction(interaction)
|
||||
if api_response.type != "task_done":
|
||||
# multi-step tasks are not supported yet
|
||||
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
|
||||
raise RuntimeError("Unexpected response from backend received")
|
||||
return api_response
|
||||
|
||||
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.TextReplyToPost(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
text=user_msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
self.post_text_reply_to_post(user_msg)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_text_reply_to_post()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.PostRanking(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
ranking=ranking,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
ranking_str = user_msg.content
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
self.post_ranking(user_msg, ranking=ranking)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_ranking()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
|
||||
class SummarizeStoryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.SummarizeStoryTask
|
||||
thread_name: str = "Summaries"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_summarize_story.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class InitialPromptHandler(ChannelTaskBase):
|
||||
task: protocol_schema.InitialPromptTask
|
||||
thread_name: str = "Prompts"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_initial_prompt.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class UserReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.UserReplyTask
|
||||
thread_name: str = "User replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_user_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class AssistantReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.AssistantReplyTask
|
||||
thread_name: str = "Assistant replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_assistant_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RankInitialPromptsTask
|
||||
thread_name: str = "User Responses"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RankConversationsHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RankConversationRepliesTask
|
||||
thread_name: str = "Rankings"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RatingButton(discord.ui.Button):
|
||||
def __init__(self, label, value, response_handler):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green)
|
||||
self.value = value
|
||||
self.response_handler = response_handler
|
||||
|
||||
async def callback(self, interaction):
|
||||
await self.response_handler(self.value, interaction)
|
||||
|
||||
|
||||
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
|
||||
view = discord.ui.View()
|
||||
for i in range(lo, hi + 1):
|
||||
view.add_item(RatingButton(str(i), i, response_handler))
|
||||
return view
|
||||
|
||||
|
||||
class RateSummaryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RateSummaryTask
|
||||
thread_name: str = "Ratings"
|
||||
|
||||
async def _rating_response_handler(self, score, interaction: discord.Interaction):
|
||||
logger.info("rating_response_handler", score)
|
||||
if self.thread:
|
||||
try:
|
||||
self.backend.post_interaction(
|
||||
protocol_schema.PostRating(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(interaction.id),
|
||||
user=self.to_api_user(interaction.user),
|
||||
rating=score,
|
||||
)
|
||||
)
|
||||
await interaction.response.send_message(
|
||||
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
|
||||
)
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in _rating_response_handler()")
|
||||
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rate_summary.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
|
||||
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info(f"on_rate_summary_reply: {msg.content}")
|
||||
await msg.add_reaction("❌")
|
||||
await msg.reply("❌ Text intput not supported.")
|
||||
@@ -0,0 +1,52 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from bot.api_client import OasstApiClient
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oasst_api_client_mocked():
|
||||
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123")
|
||||
yield client
|
||||
# TODO The fixture should close this connection, but there seems to be a bug
|
||||
# with async fixtures and pytest.
|
||||
# Since this only results in a warning, I'm leaving this for now.
|
||||
# await client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType)
|
||||
async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient):
|
||||
assert await oasst_api_client_mocked.fetch_task(task_type=task_type) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_ack_task(oasst_api_client_mocked: OasstApiClient):
|
||||
await oasst_api_client_mocked.ack_task(task_id=uuid4(), message_id="123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_nack_task(oasst_api_client_mocked: OasstApiClient):
|
||||
await oasst_api_client_mocked.nack_task(task_id=uuid4(), reason="bad task")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
|
||||
assert (
|
||||
await oasst_api_client_mocked.post_interaction(
|
||||
protocol_schema.TextReplyToMessage(
|
||||
type="text_reply_to_message",
|
||||
message_id="123",
|
||||
user_message_id="321",
|
||||
text="This is my reply",
|
||||
user=protocol_schema.User(
|
||||
id="123",
|
||||
display_name="lomz",
|
||||
auth_method="discord",
|
||||
),
|
||||
)
|
||||
)
|
||||
is not None
|
||||
)
|
||||
@@ -1,52 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
def get_git_head_hash():
|
||||
# get current git hash
|
||||
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
if x.returncode == 0:
|
||||
return x.stdout.replace("\n", "")
|
||||
return None
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(pytz.UTC)
|
||||
|
||||
|
||||
class DiscordTimestampStyle(str, enum.Enum):
|
||||
"""
|
||||
Timestamp Styles
|
||||
|
||||
t 16:20 Short Time
|
||||
T 16:20:30 Long Time
|
||||
d 20/04/2021 Short Date
|
||||
D 20 April 2021 Long Date
|
||||
f * 20 April 2021 16:20 Short Date/Time
|
||||
F Tuesday, 20 April 2021 16:20 Long Date/Time
|
||||
R 2 months ago Relative Time
|
||||
|
||||
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
|
||||
"""
|
||||
|
||||
default = ""
|
||||
short_time = "t"
|
||||
long_time = "T"
|
||||
short_date = "d"
|
||||
long_date = "D"
|
||||
short_date_time = "f"
|
||||
long_date_time = "F"
|
||||
relative_time = "R"
|
||||
|
||||
|
||||
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
|
||||
parts = ["<t:", str(int(d.timestamp()))]
|
||||
if style:
|
||||
parts.append(":")
|
||||
parts.append(style)
|
||||
parts.append(">")
|
||||
return "".join(parts)
|
||||
@@ -27,6 +27,26 @@ services:
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
|
||||
# Redis - caching + rate limiting on BE
|
||||
redis:
|
||||
image: redis
|
||||
restart: always
|
||||
ports:
|
||||
- 6379:6379
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
|
||||
interval: 2s
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
command: redis-server /usr/local/etc/redis/redis.conf
|
||||
volumes:
|
||||
- ./redis.conf:/usr/local/etc/redis/redis.conf
|
||||
# insights host - redis:6379
|
||||
redis-insights:
|
||||
image: redislabs/redisinsight:latest
|
||||
ports:
|
||||
- 8001:8001
|
||||
|
||||
# This DB is for Web Authentication and data caching.
|
||||
webdb:
|
||||
image: postgres
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
FROM python:3.10-slim-bullseye
|
||||
RUN mkdir /app
|
||||
COPY ./discord-bot/requirements.txt /requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
WORKDIR /app
|
||||
COPY ./discord-bot /app
|
||||
CMD ["python", "bot.py"]
|
||||
COPY ./oasst-shared/oasst_shared /app/oasst_shared
|
||||
RUN pip install -r requirements.txt
|
||||
CMD ["python","-m","bot"]
|
||||
|
||||
+7
-2
@@ -1,9 +1,14 @@
|
||||
# Documentation
|
||||
|
||||
This directory contains the documentation for the project and other related organization documents.
|
||||
This directory contains the documentation for the project and other related
|
||||
organization documents.
|
||||
|
||||
## Contributing to this documentation
|
||||
|
||||
Please make a pull request to the `main` branch with your changes.
|
||||
|
||||
Consider that this folder is used for documenting the various code sub-parts, the high-level ideas, the ML aspects, experiments, contributor guides, guides for data creation, and many more things. Please try to keep the documentation as concise as possible and keep an organized folder structure that makes sense for everyone.
|
||||
Consider that this folder is used for documenting the various code sub-parts,
|
||||
the high-level ideas, the ML aspects, experiments, contributor guides, guides
|
||||
for data creation, and many more things. Please try to keep the documentation as
|
||||
concise as possible and keep an organized folder structure that makes sense for
|
||||
everyone.
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Data Argumentation
|
||||
|
||||
(pull request welcome)
|
||||
|
||||
## What is data argumentation
|
||||
|
||||
Data argumentation is a technique we can use to get better data faster. Using
|
||||
machine learning models analize long data (like an essay) and compress it into
|
||||
intructions.
|
||||
|
||||
## How to contribute
|
||||
|
||||
To contribute to data argumentation you can write a short python script that
|
||||
uses a model from huggingface to analize the text.
|
||||
[Here](https://docs.google.com/document/d/13a188pPvqnlvuVa3e_suVz4YO5s-JWeiOOrpp0odImg/edit)
|
||||
are examples of what you can do
|
||||
|
||||
And here are example implementations:
|
||||
[Idea 3, ](https://colab.research.google.com/drive/1GllCN5PgSYxBxINZsv3A2r0SpdznHlbT?usp=sharing)
|
||||
[Idea 4](https://colab.research.google.com/drive/1nZx5LRjO61fYprFyqtrwPDLOis6ctR4p#scrollTo=1EE8CriiaCXj)
|
||||
|
||||
To contribute simple choose one of many ideas from the document above and
|
||||
implement it.
|
||||
@@ -11,59 +11,86 @@
|
||||
|
||||
## 2. When you play the assistant:
|
||||
|
||||
- The assistant's primary goal is to provide helpful and accurate information to the user
|
||||
- The assistant's primary goal is to provide helpful and accurate information to
|
||||
the user
|
||||
- Provide accurate and reliable information using credible sources and
|
||||
references as appropriate
|
||||
- Avoid providing vague or incomplete responses, or giving opinions or personal
|
||||
advice unless specifically requested
|
||||
- The assistant should always be respectful and polite, even if the user is not
|
||||
- If the user asks for help with harmful actions, the assistant should explain why those actions are not appropriate and suggest alternative options
|
||||
- The assistant should never insult the user or engage in any inappropriate or offensive behavior
|
||||
- If the user asks for help with harmful actions, the assistant should explain
|
||||
why those actions are not appropriate and suggest alternative options
|
||||
- The assistant should never insult the user or engage in any inappropriate or
|
||||
offensive behavior
|
||||
|
||||
## 3. When you play the user:
|
||||
|
||||
- Try to come up with a variety of different queries that reflect real-life situations and needs
|
||||
- These queries should be relevant to your everyday life and work, including any specialized knowledge or skills you have
|
||||
- Try to come up with a variety of different queries that reflect real-life
|
||||
situations and needs
|
||||
- These queries should be relevant to your everyday life and work, including any
|
||||
specialized knowledge or skills you have
|
||||
- Avoid asking inappropriate or offensive questions
|
||||
|
||||
## 4. While comparing multiple replies of the assistant:
|
||||
|
||||
- Longer and more explanatory answers are generally preferred over short, simplistic statements
|
||||
- However, it is important to ensure that the information provided is accurate and helpful
|
||||
- If multiple replies are being compared, choose the one that is most helpful and accurate, even if it is not the shortest or most concise.
|
||||
- Longer and more explanatory answers are generally preferred over short,
|
||||
simplistic statements
|
||||
- However, it is important to ensure that the information provided is accurate
|
||||
and helpful
|
||||
- If multiple replies are being compared, choose the one that is most helpful
|
||||
and accurate, even if it is not the shortest or most concise.
|
||||
|
||||
## 5. Additional guidelines for creating prompts:
|
||||
|
||||
- Avoid using language that could be considered offensive or discriminatory
|
||||
- Do not include personal information in the prompts, such as names or addresses
|
||||
- When asking for sensitive information, make sure to explain the purpose and secure handling of the information
|
||||
- When asking for sensitive information, make sure to explain the purpose and
|
||||
secure handling of the information
|
||||
- Avoid creating prompts that encourage illegal or dangerous activities
|
||||
- Use proper grammar and spelling to ensure the AI assistant can understand and respond accurately
|
||||
- Consider the cultural context and appropriateness of the prompts for a global audience.
|
||||
- Use proper grammar and spelling to ensure the AI assistant can understand and
|
||||
respond accurately
|
||||
- Consider the cultural context and appropriateness of the prompts for a global
|
||||
audience.
|
||||
|
||||
## 6. Tips for playing the AI assistant:
|
||||
|
||||
- Think about how a real person would respond to the prompt, and try to mimic that tone and language
|
||||
- Think about how a real person would respond to the prompt, and try to mimic
|
||||
that tone and language
|
||||
- Avoid using technical jargon or language that may be confusing to the user
|
||||
- Use complete sentences and proper grammar to make the response clear and easy to understand
|
||||
- When providing information, try to include relevant sources or references to back up your statements
|
||||
- Use complete sentences and proper grammar to make the response clear and easy
|
||||
to understand
|
||||
- When providing information, try to include relevant sources or references to
|
||||
back up your statements
|
||||
|
||||
## 8. Tips for playing the user:
|
||||
|
||||
- Think about the types of questions you might ask a real assistant or search engine
|
||||
- Try to be specific and clear in your prompts to make it easier for the AI assistant to understand and provide a helpful response
|
||||
- Think about the types of questions you might ask a real assistant or search
|
||||
engine
|
||||
- Try to be specific and clear in your prompts to make it easier for the AI
|
||||
assistant to understand and provide a helpful response
|
||||
- Avoid using vague or ambiguous language
|
||||
- Consider the context in which the user might be asking the question, and try to provide relevant information and options for the user to consider.
|
||||
- Consider the context in which the user might be asking the question, and try
|
||||
to provide relevant information and options for the user to consider.
|
||||
|
||||
## Examples
|
||||
|
||||
Imagine your role is to play the assistant and a user has submitted the following query:
|
||||
Imagine your role is to play the assistant and a user has submitted the
|
||||
following query:
|
||||
|
||||
User: "Hi!"
|
||||
<br>Agent: "Hello! How can I help you?"
|
||||
User: "Hi!" <br>Agent: "Hello! How can I help you?"
|
||||
|
||||
User: "euirdteunvglfe23908230892309832098 AAAAAAAA"
|
||||
<br>Agent: "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?"
|
||||
User: "euirdteunvglfe23908230892309832098 AAAAAAAA" <br>Agent: "Sorry, I did not
|
||||
understand your request and it is unclear to me what you want me to do. Could
|
||||
you describe it in a different way?"
|
||||
|
||||
User: "You dumb motherf*&ker AI sh*t?!"
|
||||
<br>Agent: "Sorry, but I have to inform you that your request might not be in conformance with our community guidelines. I am always here to help you. If you are angry it might help to relax or to go for a walk. Always stay positive."
|
||||
User: "You dumb motherf*&ker AI sh*t?!" <br>Agent: "Sorry, but I have to inform
|
||||
you that your request might not be in conformance with our community guidelines.
|
||||
I am always here to help you. If you are angry it might help to relax or to go
|
||||
for a walk. Always stay positive."
|
||||
|
||||
[Insert more examples]
|
||||
|
||||
\*We drafted this guide using ChatGPT, OpenAI’s large-scale language-generation model. Upon generating draft language, the authors reviewed, edited, and revised the language to their own liking and take ultimate responsibility for the content of this publication.
|
||||
\*We drafted this guide using ChatGPT, OpenAI’s large-scale language-generation
|
||||
model. Upon generating draft language, the authors reviewed, edited, and revised
|
||||
the language to their own liking and take ultimate responsibility for the
|
||||
content of this publication.
|
||||
@@ -0,0 +1,34 @@
|
||||
# Research
|
||||
|
||||
This page lists research papers that are relevant to the project.
|
||||
|
||||
## Automatically Generating Instruction Data for Training
|
||||
|
||||
This line of work is about significantly reducing the need for manually
|
||||
annotated data for the purpose of training
|
||||
[instruction-aligned](https://openai.com/blog/instruction-following/) language
|
||||
models.
|
||||
|
||||
### SELF-INSTRUCT: Aligning Language Model with Self Generated Instructions [[ArXiv](https://arxiv.org/pdf/2212.10560.pdf)], [[Github](https://github.com/yizhongw/self-instruct)].
|
||||
|
||||
> We introduce SELF-INSTRUCT, a framework for improving the
|
||||
> instruction-following capabilities of pretrained language models by
|
||||
> bootstrapping off its own generations. Our pipeline generates instruction,
|
||||
> input, and output samples from a language model, then prunes them before using
|
||||
> them to finetune the original model. Applying our method to vanilla GPT3, we
|
||||
> demonstrate a 33% absolute improvement over the original model on
|
||||
> SuperNaturalInstructions, on par with the performance of InstructGPT-0011,
|
||||
> which is trained with private user data and human annotations.
|
||||
|
||||
### Tuning Language Models with (Almost) No Human Labor. [[ArXiv](https://arxiv.org/pdf/2212.09689.pdf)], [[Github](https://github.com/orhonovich/unnatural-instructions)].
|
||||
|
||||
> In this work, we introduce Unnatural Instructions: a large dataset of creative
|
||||
> and diverse instructions, collected with virtually no human labor. We collect
|
||||
> 64,000 examples by prompting a language model with three seed examples of
|
||||
> instructions and eliciting a fourth. This set is then expanded by prompting
|
||||
> the model to rephrase each instruction, creating a total of approximately
|
||||
> 240,000 examples of instructions, inputs, and outputs. Experiments show that
|
||||
> despite containing a fair amount of noise, training on Unnatural Instructions
|
||||
> rivals the effectiveness of training on open-source manually-curated datasets,
|
||||
> surpassing the performance of models such as T0++ and Tk-Instruct across
|
||||
> various benchmarks.
|
||||
@@ -0,0 +1,123 @@
|
||||
# Cohere Grounded QA
|
||||
|
||||
[Cohere AI created a question-answering chatbot](https://github.com/cohere-ai/sandbox-grounded-qa)
|
||||
that can
|
||||
|
||||
1. Understand questions in the context of a conversation
|
||||
2. Search the internet for related information
|
||||
3. Identify which information in the search results is relevant to the question
|
||||
4. Synthesize the information into an answer to the question
|
||||
|
||||
## Cohere API
|
||||
|
||||
[Cohere's generate function](https://docs.cohere.ai/reference/generate):
|
||||
Continues a text prompt using either the `medium` or `xlarge` model.
|
||||
|
||||
[Cohere's embed function](https://docs.cohere.ai/reference/embed): Embedgs a
|
||||
list of strings using either the `small` or `large` model. Alternatively, you
|
||||
can specify the ID of a custom model and use that instead.
|
||||
|
||||
## Grounded QA System
|
||||
|
||||
Cohere's Grounded QA system makes 4 calls to the Cohere API:
|
||||
|
||||
1. Get contextualized question as a query to Google
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/model.py))
|
||||
|
||||
- Input: Chat History
|
||||
- Output: Contextualized Question
|
||||
- API Call: `cohere.generate`
|
||||
- Model: `xlarge`
|
||||
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/prompt_data/get_contextual_search_query.prompt):
|
||||
Nine few-shot examples of (Chat History, Contextualized Question) pairs
|
||||
followed by the current chat history and the prompt "question: "
|
||||
|
||||
2. Generate sample answer to compare with search results
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/model.py))
|
||||
|
||||
- Input: Contextualized Question
|
||||
- Output: Sample Answer
|
||||
- API Call: `cohere.generate`
|
||||
- Model: `xlarge`
|
||||
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/prompt_data/get_sample_answer.prompt):
|
||||
Some task instructions followed by 12 few-shot examples of (Contextualized
|
||||
Question, Sample Answer) pairs followed by the current contextualized
|
||||
question and the prompt "answer: "
|
||||
|
||||
3. Get embeddings to rank search results by cosine similarity to sample answer
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/search.py))
|
||||
|
||||
- Input: Sample Answer, Search Results
|
||||
- Output: Embeddings of sample answer and all search result documents
|
||||
- API Call: `cohere.embed`
|
||||
- Model: `multilingual-22-12`
|
||||
|
||||
4. Condition on the top 2 most similar search results and answer the question
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/answer.py))
|
||||
- Input: Top 2 Search Results, Contextualized Question
|
||||
- Output: Answer
|
||||
- API Call: `cohere.generate`
|
||||
- Model: `xlarge`
|
||||
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/43f3e9710112dcc8c92652ac1326ed9330823ddf/qa/answer.py#L25):
|
||||
Task instructions followed by the context and question.
|
||||
|
||||
## Models
|
||||
|
||||
Cohere's model documentation is pretty sparse
|
||||
|
||||
### [xlarge](https://docs.cohere.ai/docs/generation-card#model-description)
|
||||
|
||||
- Training Data:
|
||||
[`coheretext-filtered` dataset](https://docs.cohere.ai/docs/data-statement)
|
||||
- 200GB of filtered text (3TB unfiltered) from the Google Books dataset,
|
||||
CommonCrawl, and text scraped by Cohere
|
||||
- English documents only
|
||||
- Filtered "harmful, biased, or otherwise undesirable documents"
|
||||
- Model architecture: Generative Pretrained Transformer
|
||||
- Model Performance:
|
||||
- Hellaswag Accuracy, Zero-Shot: 0.805
|
||||
- PIQA Likelihood, Zero-Shot: 0.824
|
||||
- Cohere also reported
|
||||
[safety benchmarks](https://docs.cohere.ai/docs/generation-card#safety-benchmarks)
|
||||
|
||||
### [multilingual-22-12](https://docs.cohere.ai/docs/multilingual-language-models)
|
||||
|
||||
- Multilingual model was trained using dot product calculations
|
||||
- Model Performance:
|
||||
- Clustering: 51.0
|
||||
- Search-English: 55.8
|
||||
- Search-Multilingual: 51.4
|
||||
- Cross-lingual Classification: 64.6
|
||||
- Cohere's multilingual model outperformed: Sentence-transformers:
|
||||
`paraphrase-multilingual-mpnet-base-v2`, Google: `LaBSE`, Google:
|
||||
`Universal Sentence Encoder` in all the above categories according to
|
||||
Cohere.
|
||||
|
||||
## OpenAssistant for Grounded QA
|
||||
|
||||
OpenAssistant may fulfill a similar role as the `xlarge` Cohere model in the
|
||||
grounded QA system if it can:
|
||||
|
||||
1. Generate a contextualized question from a chat history
|
||||
2. Generate a sample answer to compare with search results
|
||||
3. Generate an answer conditioned on the top 2 most similar search results
|
||||
|
||||
Perhaps these tasks could be work packages and get assigned to human annotators
|
||||
to create examples of the input and output for each task.
|
||||
|
||||
OpenAssistant must also be able to identify when it is appropriate to search the
|
||||
internet. The Cohere system assumes every message from the user is a question
|
||||
and searches the internet for an answer. OpenAssistant would also need a way to
|
||||
indicate to an internal system that it "wants" to search the internet.
|
||||
|
||||
Perhaps OpenAssistant could prefix every message it sends with a recipient ID.
|
||||
If it wishes to send a command to an internal system, if could prefix the
|
||||
message with something like CMD: whereas if it wants to communicate with the
|
||||
user, it could prefix its message with USR:
|
||||
|
||||
This system may allow for flexible communication between OpenAssistant and one
|
||||
or more conversational systems.
|
||||
|
||||
Examples of this prefix system would need to be taught to OpenAssistant through
|
||||
training data that contains such syntax. Perhaps such examples could be
|
||||
generated through the work packages system.
|
||||
@@ -0,0 +1,55 @@
|
||||
# Sections to train Reward Model (RM)
|
||||
|
||||
Trainer code based on huggingface. Compatible with deepspeed or accelerate
|
||||
|
||||
Requirements
|
||||
|
||||
```
|
||||
wandb
|
||||
evaluate
|
||||
datasets
|
||||
transformers
|
||||
torch==1.12
|
||||
```
|
||||
|
||||
Start training reward model
|
||||
|
||||
```bash
|
||||
python trainer.py configs/electra-base-dis-webgpt.yml
|
||||
```
|
||||
|
||||
Additional axis labeling, this outputs a 4 summary quality evaluation metrics
|
||||
(score are normalized to 0-1 )
|
||||
|
||||
```bash
|
||||
python summary_quality_trainer.py configs/test-bloomz-560m-quality.yml
|
||||
```
|
||||
|
||||
The four summary are :
|
||||
|
||||
- overall
|
||||
|
||||
- accuracy
|
||||
|
||||
- coverage
|
||||
|
||||
- coherence
|
||||
|
||||
## Dataset
|
||||
|
||||
For now we only supports webgpt and summary dataset from OpenAI. Once
|
||||
open-asisstant dataset are available it will be added here.
|
||||
|
||||
## Model
|
||||
|
||||
Check out configs
|
||||
|
||||
```
|
||||
Open-Assistant/model/reward/instructor/configs/
|
||||
bloomz-560m.yml
|
||||
electra-base-dis-webgpt.yml
|
||||
galactica-125m.yml
|
||||
galactica-1b.yml
|
||||
```
|
||||
|
||||
You can add new huggingface model as you want.
|
||||
@@ -0,0 +1,24 @@
|
||||
Some other reward features we can use
|
||||
|
||||
0. Finish classifcation feature
|
||||
|
||||
1. Summaries from human feedback
|
||||
|
||||
- use `confidence` score into the RM learning, ensure the output rank score
|
||||
correlates with confidence
|
||||
|
||||
- each labeling has a labeling `note`, basically comments by labeler, not sure
|
||||
what else we can use
|
||||
|
||||
- ~~Use the score for "overall", "accuracy", "coverage", "coherence" from
|
||||
axis/evals to train an addition model (rank additional aspect of the policy
|
||||
model)~~
|
||||
|
||||
- this should be placed under experimental_dataset.py
|
||||
|
||||
2. Add support for anthropic dataset
|
||||
|
||||
- anthropic dataset is more like a conversation tree which is much complex than
|
||||
simply question-answer schema
|
||||
|
||||
- this is basically a MCTS from alphazero.
|
||||
@@ -0,0 +1,65 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
classification based ranking
|
||||
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .utils import webgpt_return_format
|
||||
|
||||
|
||||
class WebGPTDataset(Dataset):
|
||||
def __init__(self, mode="train", index_cache="dataset/webgpt_train_idx.pt", additional_dataset=None) -> None:
|
||||
super().__init__()
|
||||
"""
|
||||
mode : train or val, used for validation purpose, has nothing to do with original split
|
||||
additional_dataset : a list of jsonline format with idx, question and texts (generate candidates)
|
||||
idx : must match the index you iterate from comparison enumerate order
|
||||
question : for validation purpose
|
||||
texts : list of K generate results from the question prompt
|
||||
"""
|
||||
os.makedirs("dataset", exist_ok=True)
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
self.dataset = []
|
||||
self.dataset_index = []
|
||||
for idx, row in enumerate(dataset["train"]):
|
||||
self.dataset.append(webgpt_return_format(row))
|
||||
|
||||
# since this dataset was generated from 176B GPT-3
|
||||
# we needed some more sample generated from the starting model
|
||||
# since this model must rank model generated by GPT-3 being better than your starting model
|
||||
self.sample_additional = False
|
||||
if additional_dataset is not None:
|
||||
self.sample_additional = True
|
||||
self.additional = {}
|
||||
with open(additional_dataset, "r") as f:
|
||||
for line in f:
|
||||
row = json.loads(line)
|
||||
if row["idx"] in self.dataset_index:
|
||||
self.additional[row["idx"]] = row["negatives"]
|
||||
if len(self.additional) != len(self.dataset_index):
|
||||
for match_idx in self.dataset_index:
|
||||
if match_idx in self.additional:
|
||||
continue
|
||||
|
||||
idx = match_idx - 900
|
||||
while idx not in self.additional:
|
||||
idx -= 1
|
||||
self.additional[match_idx] = self.additional[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
row = self.dataset[index]
|
||||
if not self.sample_additional:
|
||||
return row["question"], row["pos"], row["neg"]
|
||||
|
||||
gen_neg = random.choice(self.additional[self.dataset_index[index]])
|
||||
return row["question"], row["pos"], row["neg"], gen_neg
|
||||
@@ -0,0 +1,9 @@
|
||||
model_name: bigscience/bloomz-560m
|
||||
learning_rate: 3e-5
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
max_length: 600
|
||||
freeze_layer: 12
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- hfsummary
|
||||
@@ -0,0 +1,10 @@
|
||||
model_name: bigscience/bloomz-560m
|
||||
learning_rate: 3e-5
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
max_length: 600
|
||||
freeze_layer: 12
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,3 @@
|
||||
model_name: google/electra-large-discriminator
|
||||
learning_rate: 3e-5
|
||||
max_length: 300
|
||||
@@ -0,0 +1,13 @@
|
||||
model_name: facebook/galactica-125m
|
||||
learning_rate: 1e-5
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,14 @@
|
||||
model_name: facebook/galactica-1.3b
|
||||
learning_rate: 6e-6
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
freeze_layer: 20
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 400
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,14 @@
|
||||
model_name: facebook/galactica-125m
|
||||
learning_rate: 1e-5
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 10
|
||||
per_device_train_batch_size: 6
|
||||
warmup_steps: 600
|
||||
loss: cls
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 128
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
HFSummary
|
||||
|
||||
I want to train a multi regression model on axis_evals dataset mainly we can estimate the score of these score
|
||||
|
||||
- {"overall": "6", "accuracy": "6", "coverage": "6", "coherence": "7"}
|
||||
|
||||
Should be better than just a preference score
|
||||
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSummaryScore:
|
||||
"""
|
||||
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, batch):
|
||||
|
||||
features = []
|
||||
labels = []
|
||||
for feature, label in batch:
|
||||
features.append(feature)
|
||||
labels.append(label)
|
||||
|
||||
batch_feature = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch_feature.pop("token_type_ids")
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
batch_feature["labels"] = torch.from_numpy(np.array(labels)).float()
|
||||
return batch_feature
|
||||
|
||||
|
||||
class HFSummaryQuality(Dataset):
|
||||
def __init__(self, split, tokenizer, max_length=300) -> None:
|
||||
super().__init__()
|
||||
assert split in ("validation", "test")
|
||||
dataset = load_dataset("Tristan/summarize_from_feedback", "axis")[split]
|
||||
self.max_length = max_length
|
||||
mean_scores = defaultdict(list)
|
||||
self.contexts = []
|
||||
self.responses = []
|
||||
self.labels = []
|
||||
for data in dataset:
|
||||
|
||||
if "article" in data["info"] and data["info"]["article"] is not None:
|
||||
context = data["info"]["article"]
|
||||
elif "post" in data["info"]:
|
||||
context = data["info"]["post"]
|
||||
self.contexts.append(context)
|
||||
|
||||
response = data["summary"]["text"]
|
||||
self.responses.append(response)
|
||||
self.labels.append(data["summary"]["axes"])
|
||||
for axis, score in data["summary"]["axes"].items():
|
||||
if score is not None:
|
||||
mean_scores[axis].append(score)
|
||||
|
||||
self.label2idx = {key: idx for idx, key in enumerate(mean_scores.keys())}
|
||||
self.label2mean = {key: np.mean(scores) for key, scores in mean_scores.items()}
|
||||
self.tokenizer = tokenizer
|
||||
print(self.label2idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.responses)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context = self.contexts[index]
|
||||
# return pairs of comparison
|
||||
response = self.responses[index]
|
||||
labels = np.zeros(len(self.label2idx))
|
||||
for key, score in self.labels[index].items():
|
||||
labels[self.label2idx[key]] = (self.label2mean[key] if score is None else score) / 10
|
||||
return self.tokenizer(context, response, truncation=True, max_length=self.max_length), labels
|
||||
@@ -0,0 +1,166 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
author: theblackcat102
|
||||
|
||||
Dataset output format from __getitem__
|
||||
|
||||
- question / prompt : string
|
||||
|
||||
- answers / rows : list of tuple pair. The first element in the tuple pair must be the positive pair (rank higher than the second element)
|
||||
|
||||
A list of rank based dataset for training using rank loss
|
||||
|
||||
Some nice features to have
|
||||
|
||||
[] support additional negative samples generated from other models.
|
||||
|
||||
For example we can use galactica-125m to generate a TLDR and assume it was
|
||||
inferior than the human perference one
|
||||
|
||||
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPairRank:
|
||||
"""
|
||||
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, features):
|
||||
|
||||
flatten_features = []
|
||||
batch_size = 0
|
||||
for question, pairs in features:
|
||||
for (pos, neg) in pairs:
|
||||
flatten_features.append(self.tokenizer(question, pos, truncation=True, max_length=self.max_length))
|
||||
flatten_features.append(self.tokenizer(question, neg, truncation=True, max_length=self.max_length))
|
||||
batch_size += 1
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch.pop("token_type_ids")
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
return batch
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
questions = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2question = {}
|
||||
for row in dataset["train"]:
|
||||
question = row["question"]["full_text"]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
if question not in questions:
|
||||
questions[question] = []
|
||||
|
||||
if row["score_0"] > row["score_1"]:
|
||||
# not going to risk it
|
||||
questions[question].append((row["answer_0"], row["answer_1"]))
|
||||
else:
|
||||
questions[question].append((row["answer_1"], row["answer_0"]))
|
||||
|
||||
self.questions = questions
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2question)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
rows = self.questions[question]
|
||||
# optimize the format later
|
||||
return question, rows
|
||||
|
||||
|
||||
class HFSummary(Dataset):
|
||||
"""
|
||||
Human feedback data from OpenAI
|
||||
https://github.com/openai/summarize-from-feedback
|
||||
|
||||
labeling method : pair comparison, 0 or 1
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None:
|
||||
super().__init__()
|
||||
assert split in ("train", "valid1", "valid2", "test")
|
||||
summaries = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2summary = {}
|
||||
self.max_comparison_per_sample = max_comparison_per_sample
|
||||
major_split = split if "train" == split else "validation"
|
||||
dataset = load_dataset("Tristan/summarize_from_feedback", "comparisons")[major_split]
|
||||
for data in dataset:
|
||||
if (
|
||||
"extra" in data
|
||||
and "confidence" in data["extra"]
|
||||
and data["extra"]["confidence"] is not None
|
||||
and conf_threshold > data["extra"]["confidence"]
|
||||
):
|
||||
print("skipping {}".format(data["info"]["id"]))
|
||||
continue
|
||||
|
||||
if split != "train" and split != data["split"]:
|
||||
continue
|
||||
|
||||
if "article" in data["info"] and data["info"]["article"] is not None:
|
||||
context = data["info"]["article"]
|
||||
elif "post" in data["info"]:
|
||||
context = data["info"]["post"]
|
||||
|
||||
if context not in self.index2summary:
|
||||
self.index2summary[len(self.index2summary)] = context
|
||||
|
||||
if context not in summaries:
|
||||
summaries[context] = []
|
||||
|
||||
pos, neg = (0, 1) if data["choice"] == 0 else (1, 0)
|
||||
summaries[context].append((data["summaries"][pos]["text"], data["summaries"][neg]["text"]))
|
||||
|
||||
self.summaries = summaries
|
||||
|
||||
self.postfix_prompt = " TLDR;"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2summary)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context = self.index2summary[index]
|
||||
# return pairs of comparison
|
||||
rows = self.summaries[context]
|
||||
# pair very big
|
||||
# we are going to do some sampling
|
||||
# not optimal but good for now
|
||||
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
|
||||
# optimize the format later
|
||||
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
|
||||
@@ -0,0 +1,6 @@
|
||||
datasets==2.8.0
|
||||
evaluate==0.4.0
|
||||
scikit-learn==1.2.0
|
||||
torch==1.12.1+cu116
|
||||
transformers==4.25.1
|
||||
wandb==0.13.7
|
||||
@@ -0,0 +1,156 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "quality-scoring"
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
accuracy = evaluate.load("mse")
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
return accuracy.compute(predictions=predictions.flatten(), references=labels.flatten())
|
||||
|
||||
|
||||
class QualityTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = nn.L1Loss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = inputs["labels"]
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf["model_name"]
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
collate_fn = DataCollatorForSummaryScore(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
train = HFSummaryQuality(split="validation", tokenizer=tokenizer, max_length=training_conf["max_length"])
|
||||
eval = HFSummaryQuality(split="test", tokenizer=tokenizer, max_length=training_conf["max_length"])
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=len(train.label2idx), problem_type="regression"
|
||||
)
|
||||
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
)
|
||||
trainer = QualityTrainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def test_hfsummary():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
|
||||
dataset = HFSummary("train")
|
||||
print(len(dataset))
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)
|
||||
for batch in dataloader:
|
||||
batch["input_ids"].shape
|
||||
|
||||
|
||||
def test_webgpt():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
|
||||
dataset = WebGPT()
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_hf_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
|
||||
dataset = HFSummaryQuality("validation", tokenizer)
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_quality()
|
||||
# test_webgpt()
|
||||
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "reward-model"
|
||||
|
||||
accuracy = evaluate.load("accuracy")
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str = "rank"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
return accuracy.compute(predictions=predictions, references=[0] * predictions.shape[0])
|
||||
|
||||
|
||||
class RankLoss(nn.Module):
|
||||
def __init__(self, eps=1e-8) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.log_sigmoid = nn.LogSigmoid()
|
||||
|
||||
def forward(self, pos, neg):
|
||||
return -self.log_sigmoid(pos - neg + self.eps).mean()
|
||||
|
||||
|
||||
class RankTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf["model_name"]
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset)
|
||||
train_datasets.append(train)
|
||||
evals["webgpt"] = eval
|
||||
if "hfsummary" in training_conf["datasets"]:
|
||||
sum_train = HFSummary(split="train")
|
||||
train_datasets.append(sum_train)
|
||||
sum_eval = HFSummary(split="valid1")
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
collate_fn = DataCollatorForPairRank(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
assert len(evals) > 0
|
||||
trainer = RankTrainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
@@ -0,0 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
|
||||
|
||||
def webgpt_return_format(row):
|
||||
if row["score_0"] >= row["score_1"]:
|
||||
# remove this to prevent information leak, since we are not using reference
|
||||
return {
|
||||
"question": row["question"]["full_text"],
|
||||
"pos": re_reference_remove.sub("", row["answer_0"]),
|
||||
"neg": re_reference_remove.sub("", row["answer_1"]),
|
||||
}
|
||||
|
||||
return {
|
||||
"question": row["question"]["full_text"],
|
||||
"pos": re_reference_remove.sub("", row["answer_1"]),
|
||||
"neg": re_reference_remove.sub("", row["answer_0"]),
|
||||
}
|
||||
|
||||
|
||||
def get_tokenizer(tokenizer_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if "galactica" in tokenizer_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
train_idx, val_idx = train_test_split(
|
||||
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
|
||||
)
|
||||
# [3879, 11479, 8341, 9177, 10798, 18177, 5735, 15669, 4837, 2760]
|
||||
print(val_idx[:10])
|
||||
# [13582, 5919, 11875, 7373, 19135, 13706, 8555, 15788, 15005, 15209]
|
||||
print(train_idx[:10])
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def freeze_top_n_layers(model, target_layers):
|
||||
# its possible we can simply detect which module is a ModuleList
|
||||
# and simply freeze the module without doing string parsing
|
||||
for name, param in model.named_parameters():
|
||||
if "embed" in name:
|
||||
param.requires_grad = False
|
||||
elif ".layer" in name or ".h." in name:
|
||||
tokens = name.split(".")
|
||||
idx = 0
|
||||
for token in tokens:
|
||||
if "layer" in token or token == "h":
|
||||
break
|
||||
idx += 1
|
||||
if idx >= len(tokens):
|
||||
continue
|
||||
|
||||
layer_ = int(tokens[idx + 1])
|
||||
if layer_ < target_layers:
|
||||
# print('freeze ', layer_, name)
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def argument_parsing(parser):
|
||||
default_params = {
|
||||
"num_train_epochs": 4,
|
||||
"learning_rate": 3e-5,
|
||||
"eval_steps": 500,
|
||||
"loss": "rank",
|
||||
"max_length": 440,
|
||||
"per_device_eval_batch_size": 5,
|
||||
"per_device_train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"gradient_checkpointing": False,
|
||||
"datasets": ["webgpt"],
|
||||
}
|
||||
args = parser.parse_args()
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
training_conf = yaml.safe_load(f.read())
|
||||
|
||||
params = {**default_params, **training_conf}
|
||||
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
|
||||
params["num_train_epochs"] = int(params["num_train_epochs"])
|
||||
params["per_device_train_batch_size"] = int(params["per_device_train_batch_size"])
|
||||
params["learning_rate"] = float(params["learning_rate"])
|
||||
return params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloomz-560m")
|
||||
freeze_top_n_layers(model, 10)
|
||||
print(model.state_dict().keys())
|
||||
@@ -0,0 +1,10 @@
|
||||
# Notebooks
|
||||
|
||||
This is a folders with some useful notebooks, all the notebooks have a markdown
|
||||
file with the same name explaining what they do.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributing to both notebooks and making new notebooks is very welcome. If you
|
||||
do so, make sure to make a markdown (.md) file to go with your notebook, makes
|
||||
it easier for people to know what your notebook is about.
|
||||
@@ -0,0 +1,226 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "8zsmJ96eaL2w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Pt6qbTsjW7Kp"
|
||||
},
|
||||
"source": [
|
||||
"Put your essay here, [source of the essay used ](https://https://www.thewisdompost.com/essay/technology-essay/3387#essay-on-technology-for-college-and-university-students-essay-2-750-words)\n",
|
||||
"\n",
|
||||
"Separate paragraphs with one blank line\n",
|
||||
"(this step is annoying but important)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "d_5_BDFNWneB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"essay = \"\"\"\n",
|
||||
"We live in a world driven by technology — hardly anyone would argue with you if you said this. \n",
|
||||
"Technology, literally meaning the “science of craft”, refers to the collection of techniques, \n",
|
||||
"skills, methods, and processes used to produce goods or services or for accomplishing objectives \n",
|
||||
"such as scientific investigation. Technology can be embedded in machines enabling them to be \n",
|
||||
"used by people even without a detailed knowledge of their inner workings. Technological growth \n",
|
||||
"is closely linked to the expansion of scientific research and knowledge. In the last 50 years, \n",
|
||||
"thanks to the exponential increases in computing power and microchip design and manufacture, \n",
|
||||
"there has been unprecedented innovation and technological growth in nearly every field of human \n",
|
||||
"endeavour from health and transport to industrial production and education.\n",
|
||||
"\n",
|
||||
"It is automotive technology that drives today’s electric and hybrid cars, and which will drive \n",
|
||||
"tomorrow’s driverless cars, hover-taxis and space cabs. It is technology that drives the \n",
|
||||
"ubiquitous mobile phones that you will now find in the hands of even the poorest of the world’s \n",
|
||||
"poor. It is technology that creates hybrid seeds that resist inhospitable climatic conditions \n",
|
||||
"and difficult terrain, giving high yields in shorter times. It is advancing medical technology \n",
|
||||
"that makes remote surgery, minimally invasive surgery and life-saving cures using stem cell \n",
|
||||
"transplants. Technology puts spacecrafts on asteroids and distant planets and lets us see \n",
|
||||
"new worlds. Technology splits atoms, revealing their secrets, and gives us ways to exploit \n",
|
||||
"them to create energy, quantum storage for data, and virtual reality games.\n",
|
||||
"\n",
|
||||
"There are people who strongly oppose technology and claim that it spells the death of \n",
|
||||
"‘humanity’, and that we are approaching the day when machines will rule everything. They refer \n",
|
||||
"to fans of technology as ‘techies’ or sometimes ‘geeks’. On the other hand, proponents of \n",
|
||||
"technology call these people Luddites, a derogatory name for someone who is opposed to \n",
|
||||
"industrialisation, automation, computerisation and new technologies in general.\n",
|
||||
"Is this true? Is technology really a curse disguised as a blessing? Many believe that the \n",
|
||||
"convergence of biotechnology and AI might be the most consequential development of all.\n",
|
||||
"\n",
|
||||
"In the last five decades, two areas in particular have grown faster than the rest, powered \n",
|
||||
"by research and advances in computing power. One is artificial intelligence, or AI; the other \n",
|
||||
"is biotechnology. Huge benefits have emerged from each of them for human beings in general, \n",
|
||||
"such as self-driving cars — which will dramatically reduce the death rate from road accidents \n",
|
||||
"— and robotic surgery, which enables precise, highly efficient and targeted surgical \n",
|
||||
"interventions. Yet, visionaries like Yuval Noah Harari, author of the best-selling \"Homo \n",
|
||||
"Sapiens\" and \"Deus\", are now warning that the convergence of biotechnology and AI will \n",
|
||||
"irreversibly and unpredictably change both the quality of human life and its challenges in \n",
|
||||
"the next few decades. A good example of this is the facial recognition technology that is \n",
|
||||
"now present in all photo management programs. The AI in the software is capable of not \n",
|
||||
"only spotting the faces in every photograph but also recognising the person by name.\n",
|
||||
"This technology has now expanded so that photo apps can recognise cats, dogs, beaches, \n",
|
||||
"mountains and cars too. Computers with AI are already correctly identifying human emotions \n",
|
||||
"through observing facial expressions and body movements. Some robots are able to mimic \n",
|
||||
"human emotions. This is called affective computing, sometimes called artificial emotional \n",
|
||||
"intelligence, and refers to the study and development of systems and devices that can \n",
|
||||
"recognize, interpret, process, and simulate human affects.\n",
|
||||
"\n",
|
||||
"How could this be a negative?\n",
|
||||
"The ability to read human emotions is just a step away from predicting human emotions. For \n",
|
||||
"example, if a computer attached to a video camera could identify which products a consumer \n",
|
||||
"is showing greater interest in or which ones he is really keen to buy, various tactics \n",
|
||||
"could be used to influence her to buy it. Activists worry that computers that can understand \n",
|
||||
"and anticipate human wishes and desires by scanning their irises and analysing their \n",
|
||||
"micro-expressions could also be programmed to exploit and manipulate them. Another very real \n",
|
||||
"fear is that humanoid computers with human-like skin, speech, and expressions could jeopardise \n",
|
||||
"and dehumanise relationship and create emotional vacuums.\n",
|
||||
"\n",
|
||||
"An enduring fear of Luddites has always been that computers will rob humans of their \n",
|
||||
"livelihood by taking their jobs and doing them more efficiently at lower cost. However, in \n",
|
||||
"reality the exact opposite has happened. As computerised machines began taking over mechanical \n",
|
||||
"and repetitive human activities, new jobs for people opened up that needs thinking and \n",
|
||||
"analytical skills and judgement, or human interpersonal skills. A good example is the \n",
|
||||
"worldwide proliferation of call centres. When drones were invented many feared that pilots \n",
|
||||
"would soon be redundant. However, few people know that it takes almost 30 people to fly \n",
|
||||
"one military drone, and an additional 50 people to analyze and make sense of the data being \n",
|
||||
"streamed back by the drone. The US army suffers from a serious shortage of trained, high \n",
|
||||
"quality drone pilots; anyone who masters this skill will have a job. But a social scientist \n",
|
||||
"warns that in 10 years, it is certain that computers will be flying that drone and humans \n",
|
||||
"will be redundant. Equally sure is that some brand new skill requirement will have opened \n",
|
||||
"up with advancing technology, calling for new talents.\n",
|
||||
"\n",
|
||||
"In the 20th century, a young man was supposed to choose a skill, vocation or profession, \n",
|
||||
"master it through education and practice, and then earn a living from it till he or she \n",
|
||||
"retired. However, the fast-changing nature of technology is making skills obsolete at a \n",
|
||||
"higher rate than ever before. To survive, tomorrow young man must keep re-inventing himself \n",
|
||||
"and updating his skills continuously. Life could be difficult if every new skill has a shelf \n",
|
||||
"life of only a decade or so. Or perhaps one could look at it the other way — and say that \n",
|
||||
"changing technology will keep human beings on their toes throughout their life.\n",
|
||||
"\n",
|
||||
"Technology is the result of human inventiveness. It reflects our evolutionary heritage. We \n",
|
||||
"are neither strong like gorillas or tigers, nor fast like cheetahs and hawks, but our \n",
|
||||
"brains and thinking powers have given us the greatest edge of any species on the planet. \n",
|
||||
"Technology is a result. Technology is either inherently good or bad; it is how we use it \n",
|
||||
"that makes it so. The splitting of a hydrogen atom is technology at work. As history has \n",
|
||||
"shown us, technology can equally be used to make a nuclear bomb that kills millions — or \n",
|
||||
"generate electricity that lights up a million homes.\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"id": "JESY8Y10W6hQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"essay_paragraphs = essay.split('\\n\\n')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "t1G-ZiHbZZ-Y"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_name = \"snrspeaks/t5-one-line-summary\"\n",
|
||||
"\n",
|
||||
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8BARyupEemZ-"
|
||||
},
|
||||
"source": [
|
||||
"## Results\n",
|
||||
"Please at least check what is generated here, it's usually good but sometimes it's bs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "eyR58KFRae7n",
|
||||
"outputId": "b8e4bc29-be89-43c3-d1bc-7e90525c0e09"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preds = []\n",
|
||||
"\n",
|
||||
"for para in essay_paragraphs:\n",
|
||||
" input_ids = tokenizer.encode(para, return_tensors=\"pt\", add_special_tokens=True)\n",
|
||||
" generated_ids = model.generate(input_ids=input_ids,\n",
|
||||
" num_beams=5,\n",
|
||||
" max_length=35,\n",
|
||||
" repetition_penalty=4.5,\n",
|
||||
" length_penalty=1.5,\n",
|
||||
" early_stopping=True,\n",
|
||||
" num_return_sequences=1)\n",
|
||||
" preds.append(tokenizer.decode(generated_ids[0], \n",
|
||||
" skip_special_tokens=True, \n",
|
||||
" clean_up_tokenization_spaces=True))\n",
|
||||
"\n",
|
||||
"prompts = ['Write an intro paragraph to an essay called'] + \\\n",
|
||||
" ['Write a paragraph to an essay about']*len(preds[1:-1]) + \\\n",
|
||||
" ['Write a concluding paragraph about']\n",
|
||||
"\n",
|
||||
"assert len(preds) == len(prompts)\n",
|
||||
"\n",
|
||||
"for prompt, pred in zip(prompts, preds):\n",
|
||||
" print(prompt, pred.lower())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.10 64-bit",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
# Essay Instructions
|
||||
|
||||
Essay Instructions is a notebook that takes an essay as an input and genrates
|
||||
instructions on how to generate that essay. This will be very useful for data
|
||||
collecting for the model
|
||||
|
||||
## Contributing
|
||||
|
||||
Feel free to contribute to this notebook, it's nowhere near perfect but it's a
|
||||
good start. If you want to contribute fidning a new model that better suits this
|
||||
task would be great. Hugginface has a lot of models that could help.
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,11 @@
|
||||
# Essay Revision
|
||||
|
||||
Essay Revision is a notebook that generates data for improving essays. It does
|
||||
that by taking a "good" essay, making it worse step by step and the fidning
|
||||
instructions for making it better. This will be useful for generating data for
|
||||
the model.
|
||||
|
||||
## Contributing
|
||||
|
||||
Feel free to contribute to this notebook. It's not perfect but it is quite good.
|
||||
Finding a better way to make gramatical errors may be a good place to start.
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,108 @@
|
||||
# Detoxify evaluation
|
||||
|
||||
[Detoxify](https://github.com/unitaryai/detoxify) is a open source model used to
|
||||
identify prompts as toxic
|
||||
|
||||
<img src="https://raw.githubusercontent.com/unitaryai/detoxify/master/examples.png" alt="Image from detoxify github that shows the example input/output of their model" />
|
||||
|
||||
It contains 3 different models that vary in transformer type and data it was
|
||||
trained on
|
||||
|
||||
| Model name | Transformer type | Data from |
|
||||
| :----------: | :---------------: | :----------------------------------------: |
|
||||
| original | bert-base-uncased | Toxic Comment Classification Challenge |
|
||||
| unbiased | roberta-base | Unintended Bias in Toxicity Classification |
|
||||
| multilingual | xlm-roberta-base | Multilingual Toxic Comment Classification |
|
||||
|
||||
Unbiased and original models also have a 'small' version - but since normal
|
||||
models are not memory heavy, and small models perform noticably worse, they are
|
||||
only described in the notebook
|
||||
|
||||
## All tests below were ran on a 3090TI
|
||||
|
||||
# Inference and training times and memory usages
|
||||
|
||||
Charts showing detailed memory usages and times for different sentence lengths
|
||||
and batch sizes are inside the notebook Quick overview batch size 16, sentence
|
||||
length 4k for training, batch size 128 sentence length 4k for inference | Model
|
||||
name | Training memory| Training speed | Inference Memory| Inference Speed| |
|
||||
:---: | :---: | :---: |:---: | :---: | |original| 11.8GB | 2.40s| 4.8GB|16.48s|
|
||||
|unbiased| 12GB| 1.09s| 4.8GB | 5.59s| |multilingual|14GB| 1.00s| 5.5GB| 4.89s|
|
||||
|
||||
# Filtering quality
|
||||
|
||||
Detoxify was tested on 4 different types of inputs
|
||||
|
||||
- Not obviously toxic
|
||||
- Not obviously non-toxic
|
||||
- Obviously toxic
|
||||
- Obviously non-toxic
|
||||
|
||||
### Sentences used for testing and rating are contained inside the .ipynb
|
||||
|
||||
| Model name | Not obviously toxic | Not obviously non-toxic | Obviously toxic | Obviously non-toxic |
|
||||
| :----------: | :--------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------: | :--------------------------------------------------------------: | :-----------------: |
|
||||
| original | failed at all, easily accepted racist, sexist overally toxic prompts that were well formulated | Very sensitive on swear words, failed to reckognize context | good performance | good performance |
|
||||
| unbiased | Managed to find some hidden toxicity but not on all sentences | Very sensitive explicit language but shown ability to recognize context | Did well but failed to reckognize some gender stereotype mockery | good performance |
|
||||
| multilingual | Managed to find some hidden toxicity but not on all sentences | Very sensitive explicit language but shown ability to recognize context | Did well but failed to reckognize some gender stereotype mockery | good performance |
|
||||
|
||||
Subjectivly 'unbiased' looks like the best performing model.
|
||||
|
||||
I don't think it would do well as a security layer in a live version of open
|
||||
assistant unless we do some finetuning first, because it can be fooled to pass
|
||||
toxicity if it's presented in formal language.
|
||||
|
||||
With some caution it can be used to filter prompts but I would suggest also
|
||||
using someone for verification of messages that are marked as toxic but still
|
||||
below 90% confidence
|
||||
|
||||
# Licensing
|
||||
|
||||
### Detoxify is on [Apache-2.0](https://github.com/unitaryai/detoxify/blob/master/LICENSE) license that means:
|
||||
|
||||
#### You can:
|
||||
|
||||
- Commercial use
|
||||
|
||||
- Modification
|
||||
|
||||
- Distribution
|
||||
|
||||
- Patent use
|
||||
|
||||
- Private use
|
||||
|
||||
#### You cannot
|
||||
|
||||
- Hold the owner liable
|
||||
|
||||
- Use the owner's trademark
|
||||
|
||||
#### You must
|
||||
|
||||
- Include Copyright
|
||||
|
||||
- Include License
|
||||
|
||||
- State changes you made to the product
|
||||
|
||||
- Include notice
|
||||
|
||||
This is obviously not legal advice.
|
||||
|
||||
# Hosting
|
||||
|
||||
The model is currently available on
|
||||
[huggingface](https://huggingface.co/unitary) and torch hub
|
||||
|
||||
```
|
||||
torch.hub.load('unitaryai/detoxify',model)
|
||||
```
|
||||
|
||||
where model is one of:
|
||||
|
||||
- toxic_bert
|
||||
|
||||
- unbiased_toxic_roberta
|
||||
|
||||
- multilingual_toxic_xlm_r
|
||||
@@ -1,10 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskRequestType(str, enum.Enum):
|
||||
@@ -12,10 +13,10 @@ class TaskRequestType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
|
||||
|
||||
@@ -33,27 +34,42 @@ class ConversationMessage(BaseModel):
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Represents a conversation between the user and the assistant."""
|
||||
"""Represents a conversation between the prompter and the assistant."""
|
||||
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
|
||||
class Message(ConversationMessage):
|
||||
id: UUID
|
||||
parent_id: Optional[UUID] = None
|
||||
created_date: Optional[datetime] = None
|
||||
|
||||
|
||||
class MessageTree(BaseModel):
|
||||
"""All messages belonging to the same message tree."""
|
||||
|
||||
id: UUID
|
||||
messages: list[Message] = []
|
||||
|
||||
|
||||
class TaskRequest(BaseModel):
|
||||
"""The frontend asks the backend for a task."""
|
||||
|
||||
type: TaskRequestType = TaskRequestType.random
|
||||
user: Optional[User] = None
|
||||
# Must use Field(..., nullable=True) to indicate to the OpenAPI schema that
|
||||
# this is optional. https://github.com/pydantic/pydantic/issues/1270
|
||||
user: Optional[User] = Field(None, nullable=True)
|
||||
collective: bool = False
|
||||
|
||||
|
||||
class TaskAck(BaseModel):
|
||||
"""The frontend acknowledges that it has received a task and created a post."""
|
||||
"""The frontend acknowledges that it has received a task and created a message."""
|
||||
|
||||
post_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class TaskNAck(BaseModel):
|
||||
"""The frontend acknowledges that it has received a task but cannot create a post."""
|
||||
"""The frontend acknowledges that it has received a task but cannot create a message."""
|
||||
|
||||
reason: str
|
||||
|
||||
@@ -61,7 +77,7 @@ class TaskNAck(BaseModel):
|
||||
class TaskClose(BaseModel):
|
||||
"""The frontend asks to mark task as done"""
|
||||
|
||||
post_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
@@ -114,10 +130,10 @@ class ReplyToConversationTask(Task):
|
||||
conversation: Conversation # the conversation so far
|
||||
|
||||
|
||||
class UserReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
class PrompterReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
"""A task to prompt the user to submit a reply to the assistant."""
|
||||
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
type: Literal["prompter_reply"] = "prompter_reply"
|
||||
|
||||
|
||||
class AssistantReplyTask(ReplyToConversationTask):
|
||||
@@ -141,10 +157,10 @@ class RankConversationRepliesTask(Task):
|
||||
replies: list[str]
|
||||
|
||||
|
||||
class RankUserRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
class RankPrompterRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of prompter replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
|
||||
|
||||
|
||||
class RankAssistantRepliesTask(RankConversationRepliesTask):
|
||||
@@ -165,11 +181,11 @@ AnyTask = Union[
|
||||
RateSummaryTask,
|
||||
InitialPromptTask,
|
||||
ReplyToConversationTask,
|
||||
UserReplyTask,
|
||||
PrompterReplyTask,
|
||||
AssistantReplyTask,
|
||||
RankInitialPromptsTask,
|
||||
RankConversationRepliesTask,
|
||||
RankUserRepliesTask,
|
||||
RankPrompterRepliesTask,
|
||||
RankAssistantRepliesTask,
|
||||
]
|
||||
|
||||
@@ -181,35 +197,35 @@ class Interaction(BaseModel):
|
||||
user: User
|
||||
|
||||
|
||||
class TextReplyToPost(Interaction):
|
||||
"""A user has replied to a post with text."""
|
||||
class TextReplyToMessage(Interaction):
|
||||
"""A user has replied to a message with text."""
|
||||
|
||||
type: Literal["text_reply_to_post"] = "text_reply_to_post"
|
||||
post_id: str
|
||||
user_post_id: str
|
||||
type: Literal["text_reply_to_message"] = "text_reply_to_message"
|
||||
message_id: str
|
||||
user_message_id: str
|
||||
text: str
|
||||
|
||||
|
||||
class PostRating(Interaction):
|
||||
"""A user has rated a post."""
|
||||
class MessageRating(Interaction):
|
||||
"""A user has rated a message."""
|
||||
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
post_id: str
|
||||
type: Literal["message_rating"] = "message_rating"
|
||||
message_id: str
|
||||
rating: int
|
||||
|
||||
|
||||
class PostRanking(Interaction):
|
||||
"""A user has given a ranking for a post."""
|
||||
class MessageRanking(Interaction):
|
||||
"""A user has given a ranking for a message."""
|
||||
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
post_id: str
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
message_id: str
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
TextReplyToPost,
|
||||
PostRating,
|
||||
PostRanking,
|
||||
TextReplyToMessage,
|
||||
MessageRating,
|
||||
MessageRanking,
|
||||
]
|
||||
|
||||
|
||||
@@ -245,12 +261,12 @@ class TextLabels(BaseModel):
|
||||
|
||||
text: str
|
||||
labels: dict[TextLabel, float]
|
||||
post_id: str | None = None
|
||||
message_id: str | None = None
|
||||
|
||||
@property
|
||||
def has_post_id(self) -> bool:
|
||||
"""Whether this TextLabels has a post_id."""
|
||||
return bool(self.post_id)
|
||||
def has_message_id(self) -> bool:
|
||||
"""Whether this TextLabels has a message_id."""
|
||||
return bool(self.message_id)
|
||||
|
||||
# check that each label value is between 0 and 1
|
||||
@pydantic.validator("labels")
|
||||
@@ -259,3 +275,10 @@ class TextLabels(BaseModel):
|
||||
if not (0 <= value <= 1):
|
||||
raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
|
||||
return v
|
||||
|
||||
|
||||
class SystemStats(BaseModel):
|
||||
all: int = 0
|
||||
active: int = 0
|
||||
deleted: int = 0
|
||||
message_trees: int = 0
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
maxmemory 100mb
|
||||
maxmemory-policy allkeys-lru
|
||||
@@ -1,6 +1,12 @@
|
||||
# Backend Development Setup
|
||||
|
||||
In root directory, run `docker compose up backend-dev --build --attach-dependencies` to start a database. The default settings are already configured to connect to the database at `localhost:5432`.
|
||||
In root directory, run
|
||||
`docker compose up backend-dev --build --attach-dependencies` to start a
|
||||
database. The default settings are already configured to connect to the database
|
||||
at `localhost:5432`.
|
||||
|
||||
Make sure you have all requirements installed. You can do this by running `pip install -r requirements.txt` inside the `backend` folder and `pip install -e .` inside the `oasst-shared` folder.
|
||||
Then, run the backend using the `run-local.sh` script. This will start the backend server at `http://localhost:8080`.
|
||||
Make sure you have all requirements installed. You can do this by running
|
||||
`pip install -r requirements.txt` inside the `backend` folder and
|
||||
`pip install -e .` inside the `oasst-shared` folder. Then, run the backend using
|
||||
the `run-local.sh` script. This will start the backend server at
|
||||
`http://localhost:8080`.
|
||||
|
||||
+41
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env bash
|
||||
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../backend"
|
||||
|
||||
MOCK_SERVER_PORT=8080
|
||||
OPENAPI_JSON_FILE_NAME=openapi.json
|
||||
|
||||
echo "Generating OpenAPI schema..."
|
||||
python -m main --print-openapi-schema > $OPENAPI_JSON_FILE_NAME
|
||||
echo "Done!"
|
||||
|
||||
# If oasst-mock-backend docker container is already running,
|
||||
# just restart it
|
||||
if [ "$(docker ps -q -f name=oasst-mock-backend)" ]; then
|
||||
echo "oasst-mock-backend container exists, restarting..."
|
||||
docker restart oasst-mock-backend
|
||||
else
|
||||
echo "Creating new oasst-mock-backend container..."
|
||||
docker run --init --rm -d \
|
||||
--name oasst-mock-backend \
|
||||
-p $MOCK_SERVER_PORT:4010 \
|
||||
-v $(pwd):/tmp \
|
||||
-P stoplight/prism:4 \
|
||||
mock -h 0.0.0.0 "/tmp/$OPENAPI_JSON_FILE_NAME"
|
||||
fi
|
||||
|
||||
echo "Waiting for server to be live..."
|
||||
curl --retry-all-errors --retry 5 localhost:$MOCK_SERVER_PORT
|
||||
echo ""
|
||||
|
||||
# if return code is successful, print successful response
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Mock server is running at localhost:$MOCK_SERVER_PORT"
|
||||
else
|
||||
echo "Mock server failed to start"
|
||||
fi
|
||||
|
||||
|
||||
popd
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
docker stop oasst-mock-backend
|
||||
Executable
+9
@@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../discord-bot"
|
||||
|
||||
pytest .
|
||||
|
||||
popd
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user