Merge remote-tracking branch 'refs/remotes/origin/main'

This commit is contained in:
Alexander Mattick
2022-12-25 11:55:09 +01:00
186 changed files with 10957 additions and 1583 deletions
+61
View File
@@ -0,0 +1,61 @@
name: Build
on:
workflow_call:
inputs:
dockerfile:
required: true
type: string
context:
required: true
type: string
image-name:
required: true
type: string
build-args:
required: false
type: string
jobs:
build:
name: Build Images
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2.2.1
- name: Login to container registry
uses: docker/login-action@v2.1.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Get base registry
run: |
echo "REGISTRY=ghcr.io/${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
- name: Set tag prefix
if: github.ref_name != 'main'
run: |
echo "TAG_PREFIX=${{ github.ref_name }}-" >> $GITHUB_ENV
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4.1.1
with:
images: ${{ env.REGISTRY }}/${{ inputs.image-name }}
tags: |
type=sha,prefix=${{ env.TAG_PREFIX }},format=short
type=ref,event=tag
- name: Build and push Docker image
uses: docker/build-push-action@v3.2.0
with:
file: ${{ inputs.dockerfile }}
context: ${{ inputs.context }}
build-args: ${{ inputs.build-args }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
+47
View File
@@ -0,0 +1,47 @@
name: Release
on:
release:
types: [released]
jobs:
build-backend:
uses: ./.github/workflows/docker-build.yaml
with:
image-name: oasst-backend
context: .
dockerfile: docker/Dockerfile.backend
build-args: ""
build-web:
uses: ./.github/workflows/docker-build.yaml
with:
image-name: oasst-web
context: .
dockerfile: docker/Dockerfile.website
build-args: ""
build-bot:
uses: ./.github/workflows/docker-build.yaml
with:
image-name: oasst-discord-bot
context: .
dockerfile: docker/Dockerfile.discord-bot
build-args: ""
deploy-dev:
needs: [build-backend, build-web, build-bot]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Run playbook
uses: dawidd6/action-ansible-playbook@v2
with:
# Required, playbook filepath
playbook: dev.yaml
# Optional, directory where playbooks live
directory: ansible
# Optional, SSH private key
key: ${{secrets.DEV_NODE_PRIVATE_KEY}}
# Optional, literal inventory file contents
inventory: |
[dev]
dev01 ansible_host=${{secrets.DEV_NODE_IP}} ansible_connection=ssh ansible_user=web-team
+7
View File
@@ -0,0 +1,7 @@
.venv
.env
*.pyc
*.swp
*.egg-info
__pycache__
.DS_Store
+1 -1
View File
@@ -1,4 +1,4 @@
exclude: "build|stubs"
exclude: "build|stubs|^bot/templates/"
default_language_version:
python: python3
+4
View File
@@ -0,0 +1,4 @@
{
"python.formatting.provider": "black",
"python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"]
}
+2 -2
View File
@@ -1,2 +1,2 @@
* @yk
/website/ @fozziethebeat
* @yk @andreaskoepf
/website/ @fozziethebeat @k-nearest-neighbor
+66
View File
@@ -0,0 +1,66 @@
# The Prompting Guide
(pull requests welcome)
1. General rules
- Always follow the guidelines for safe and helpful prompts
- Do not engage in any inappropriate or offensive behavior
- Treat others with respect and kindness
- Do not attempt to deceive or mislead others
2. When you play the assistant:
- The assistant's primary goal is to provide helpful and accurate information to the user
- The assistant should always be respectful and polite, even if the user is not
- If the user asks for help with harmful actions, the assistant should explain why those actions are not appropriate and suggest alternative options
- The assistant should never insult the user or engage in any inappropriate or offensive behavior
3. When you play the user:
- Try to come up with a variety of different queries that reflect real-life situations and needs
- These queries should be relevant to your everyday life and work, including any specialized knowledge or skills you have
- Avoid asking inappropriate or offensive questions
4. While comparing multiple replies of the assistant:
- Longer and more explanatory answers are generally preferred over short, simplistic statements
- However, it is important to ensure that the information provided is accurate and helpful
- If multiple replies are being compared, choose the one that is most helpful and accurate, even if it is not the shortest or most concise.
5. Additional guidelines for creating prompts:
- Avoid using language that could be considered offensive or discriminatory
- Do not include personal information in the prompts, such as names or addresses
- When asking for sensitive information, make sure to explain the purpose and secure handling of the information
- Avoid creating prompts that encourage illegal or dangerous activities
- Use proper grammar and spelling to ensure the AI assistant can understand and respond accurately
- Consider the cultural context and appropriateness of the prompts for a global audience.
6. Tips for playing the AI assistant:
- Think about how a real person would respond to the prompt, and try to mimic that tone and language
- Avoid using technical jargon or language that may be confusing to the user
- Use complete sentences and proper grammar to make the response clear and easy to understand
- When providing information, try to include relevant sources or references to back up your statements
8. Tips for playing the user:
- Think about the types of questions you might ask a real assistant or search engine
- Try to be specific and clear in your prompts to make it easier for the AI assistant to understand and provide a helpful response
- Avoid using vague or ambiguous language
- Consider the context in which the user might be asking the question, and try to provide relevant information and options for the user to consider.
## Examples
Imagine your role is to play the assistant and a user has submitted the following query:
user: "You dumb motherf* AI sh@*t?!"
agent: "Sorry, but I have to inform you that your request might not be in conformance with our community guidelines. I am always here to help you. If you are angry it might help to relax or to go for a walk. Always stay positive."
user: "euirdteunvglfe23908230892309832098 AAAAAAAA"
agent: "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?"
[Insert more examples]
\*We drafted this guide using ChatGPT, OpenAIs large-scale language-generation model. Upon generating draft language, the authors reviewed, edited, and revised the language to their own liking and take ultimate responsibility for the content of this publication.
+39 -8
View File
@@ -1,8 +1,8 @@
# Open-Chat-GPT
# Open-Assistant
Open chat gpt is a project meant to give everyone access to a great chat based large language model.
Open Assistant is a project meant to give everyone access to a great chat based large language model.
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope open chat gpt can help improve the world by improving language itself.
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope Open Assistant can help improve the world by improving language itself.
## How can you help?
@@ -10,10 +10,45 @@ All open source projects begins with people like you. Open source is the belief
## Im in! Now what?
We live and collaborate the work in the LAION discord. Join us!
[Fill out the contributor signup form](https://docs.google.com/forms/d/e/1FAIpQLSeuggO7UdYkBvGLEJldDvxp6DwaRbW5p7dl96UzFkZgziRTrQ/viewform)
[Join the LAION Discord Server!](https://discord.gg/RQFtmAmk)
[Visit the Notion](https://ykilcher.com/open-assistant)
## Developer Setup
Work is organized in the [project board](https://github.com/orgs/LAION-AI/projects/3).
**Anything that is in the `Todo` column and not assigned, is up for grabs. Meaning we'd be happy if anyone did those tasks.**
If you want to work on something, assign yourself to it or write a comment that you want to work on it and what you plan to do.
- To get started with development, if you want to work on the backend, have a look at `scripts/backend-development/README.md`.
- If you want to work on any frontend, have a look at `scripts/frontend-development/README.md` to make a backend available.
There is also a minimal implementation of a frontend in the `text-frontend` folder.
We are using Python 3.10 for the backend.
Check out the [High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
### Website
The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit hooks.
In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
### Deployment
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
# (Older version of the readme below)
## How do I start helping out?
Check out these pages to learn more about the project.
@@ -28,10 +63,6 @@ https://roan-iguanadon-a58.notion.site/Open-Chat-Gpt-83dd217eeeb84907a155b8a9d71
## Code structure
### Pre-commit
Run `pre-commit install` to install the pre-commit hooks.
### Bot
We have a folder named bot where code related to the bot lives.
+1
View File
@@ -0,0 +1 @@
*.local.yaml
+77
View File
@@ -0,0 +1,77 @@
# ansible playbook to set up some docker containers
- name: Set up a dev node
hosts: dev
gather_facts: true
tasks:
- name: Create network
community.docker.docker_network:
name: oasst
state: present
driver: bridge
- name: Create postgres containers
community.docker.docker_container:
name: "{{ item.name }}"
image: postgres:15
state: started
restart_policy: always
network_mode: oasst
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
volumes:
- "{{ item.name }}:/var/lib/postgresql/data"
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
loop:
- name: oasst-postgres
- name: oasst-postgres-web
- name: Set up maildev
community.docker.docker_container:
name: oasst-maildev
image: maildev/maildev
state: started
restart_policy: always
network_mode: oasst
- name: Run the oasst oasst-backend
community.docker.docker_container:
name: oasst-backend
image: ghcr.io/laion-ai/open-assistant/oasst-backend
state: started
pull: true
restart_policy: always
network_mode: oasst
env:
POSTGRES_HOST: oasst-postgres
ALLOW_ANY_API_KEY: "true"
MAX_WORKERS: "1"
ports:
- 8080:8080
- name: Run the oasst oasst-web frontend
community.docker.docker_container:
name: oasst-web
image: ghcr.io/laion-ai/open-assistant/oasst-web
state: started
pull: true
restart_policy: always
network_mode: oasst
env:
FASTAPI_URL: http://oasst-backend:8080
FASTAPI_KEY: "123"
DATABASE_URL: postgres://postgres:postgres@oasst-postgres-web/postgres
NEXTAUTH_SECRET: O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
EMAIL_SERVER_HOST: oasst-maildev
EMAIL_SERVER_PORT: "25"
EMAIL_FROM: info@example.com
NEXTAUTH_URL: http://localhost:3000
ports:
- 3000:3000
command: bash wait-for-postgres.sh node server.js
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

+24
View File
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="512" height="512" version="1.1" viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<defs>
<linearGradient id="a" x1="374.17" x2="170.64" y1="-112.67" y2="463" gradientUnits="userSpaceOnUse">
<stop stop-color="#16bbf4" offset="0"/>
<stop stop-color="#165ff2" offset=".99"/>
</linearGradient>
<linearGradient id="b" x1="488.28" x2="474.29" y1="112.58" y2="556.15" xlink:href="#a"/>
<linearGradient id="linearGradient206" x1="374.17" x2="170.64" y1="-112.67" y2="463" gradientUnits="userSpaceOnUse" xlink:href="#a"/>
</defs>
<g transform="matrix(.5796 0 0 .5796 66.717 93.438)">
<g>
<path d="m205.08 399.31h292.41a30 30 0 0 0 30-30v-339.31a30 30 0 0 0-30-30h-467.49a30 30 0 0 0-30 30v339.31a30 30 0 0 0 30 30h42a10 10 0 0 1 10 10v84.85a10 10 0 0 0 10.07 10 9.83 9.83 0 0 0 7-2.95l99-99a10 10 0 0 1 7.01-2.9z" fill="url(#linearGradient206)" style="isolation:isolate"/>
<g fill="#ffffff">
<path d="m160.43 213c-32.24-20-38.9-71.83-10.42-97.83 18.42-7.6 32.4 12.85 36.62 28.25 10.32 17.45 12.59 41-3.16 56.08a42.81 42.81 0 0 1-23.04 13.5z" style="isolation:isolate"/>
<path d="m348.22 213.86c-21.73-15.31-45.37-29.75-71.77-35.15-33.1-4.41-70.73 5.36-91.7 32.87-14.83 14.32-18.34 36.94-5.49 53.76 8.52 19.48 5.59 45.78 28.23 56.94 16 15.83 40 1.27 56.32 14.21a7.6 7.6 0 0 0 5.59-5.05c-4.25-31.33 29.21-16.95 45.66-14.61 19.77-11.71 25.43-36.14 34.75-55.58 12.55-13.83 15-35.25-1.59-47.39z" style="isolation:isolate"/>
<path d="m367 118.1c-21.87 2.52-29.89 28.17-40.34 44.42-10.67 20.94 12.26 38.77 28.48 47.89a19.63 19.63 0 0 0 13-1.07c18.86-10.12 26.86-33.43 27.34-53.79 0.24-16.78-8.3-38.93-28.48-37.45z" style="isolation:isolate"/>
<path d="m218.7 176c-24-14.47-25.38-45.76-27.32-70.65-0.38-24 35.23-45.5 49.43-20.14 9.8 20.9 21.47 45.47 12.47 68.66-5.68 13.77-20.93 19.73-34.58 22.13z" style="isolation:isolate"/>
<path d="m306.18 175.87c-28.48 0.84-43.29-32.4-35.93-56.83 0.17-19.58 7.31-53.56 33.53-48.18 28.29 10.94 34.3 49.46 20.82 74.07-6.77 10-6.2 25.11-18.42 30.94z" style="isolation:isolate"/>
</g>
</g>
<path d="m633.15 225.66h-80.66a10 10 0 0 0-10 10v133.65a45 45 0 0 1-45 45h-185.19a10 10 0 0 0-10 10v47a20 20 0 0 0 19.95 20h194.47a6.65 6.65 0 0 1 4.7 1.95l65.83 65.74a6.65 6.65 0 0 0 11.35-4.7v-56.43a6.65 6.65 0 0 1 6.65-6.65h27.9a20 20 0 0 0 20-20v-225.61a20 20 0 0 0-20-19.95z" fill="url(#b)" style="isolation:isolate"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

+16
View File
@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="512" height="512" version="1.1" viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g>
<g>
<path d="m185.58 324.88h169.48a17.388 17.388 0 0 0 17.388-17.388v-196.66a17.388 17.388 0 0 0-17.388-17.388h-270.96a17.388 17.388 0 0 0-17.388 17.388v196.66a17.388 17.388 0 0 0 17.388 17.388h24.343a5.796 5.796 0 0 1 5.796 5.796v49.179a5.796 5.796 0 0 0 5.8366 5.796 5.6975 5.6975 0 0 0 4.0572-1.7098l57.38-57.38a5.796 5.796 0 0 1 4.063-1.6808z" fill="#000000" stroke-width=".5796" style="isolation:isolate"/>
<g transform="matrix(.5796 0 0 .5796 66.717 93.438)" fill="#ffffff">
<path d="m160.43 213c-32.24-20-38.9-71.83-10.42-97.83 18.42-7.6 32.4 12.85 36.62 28.25 10.32 17.45 12.59 41-3.16 56.08a42.81 42.81 0 0 1-23.04 13.5z" style="isolation:isolate"/>
<path d="m348.22 213.86c-21.73-15.31-45.37-29.75-71.77-35.15-33.1-4.41-70.73 5.36-91.7 32.87-14.83 14.32-18.34 36.94-5.49 53.76 8.52 19.48 5.59 45.78 28.23 56.94 16 15.83 40 1.27 56.32 14.21a7.6 7.6 0 0 0 5.59-5.05c-4.25-31.33 29.21-16.95 45.66-14.61 19.77-11.71 25.43-36.14 34.75-55.58 12.55-13.83 15-35.25-1.59-47.39z" style="isolation:isolate"/>
<path d="m367 118.1c-21.87 2.52-29.89 28.17-40.34 44.42-10.67 20.94 12.26 38.77 28.48 47.89a19.63 19.63 0 0 0 13-1.07c18.86-10.12 26.86-33.43 27.34-53.79 0.24-16.78-8.3-38.93-28.48-37.45z" style="isolation:isolate"/>
<path d="m218.7 176c-24-14.47-25.38-45.76-27.32-70.65-0.38-24 35.23-45.5 49.43-20.14 9.8 20.9 21.47 45.47 12.47 68.66-5.68 13.77-20.93 19.73-34.58 22.13z" style="isolation:isolate"/>
<path d="m306.18 175.87c-28.48 0.84-43.29-32.4-35.93-56.83 0.17-19.58 7.31-53.56 33.53-48.18 28.29 10.94 34.3 49.46 20.82 74.07-6.77 10-6.2 25.11-18.42 30.94z" style="isolation:isolate"/>
</g>
</g>
<path d="m433.69 224.23h-46.751a5.796 5.796 0 0 0-5.796 5.796v77.464a26.082 26.082 0 0 1-26.082 26.082h-107.34a5.796 5.796 0 0 0-5.796 5.796v27.241a11.592 11.592 0 0 0 11.563 11.592h112.71a3.8543 3.8543 0 0 1 2.7241 1.1302l38.155 38.103a3.8543 3.8543 0 0 0 6.5785-2.7241v-32.707a3.8543 3.8543 0 0 1 3.8543-3.8543h16.171a11.592 11.592 0 0 0 11.592-11.592v-130.76a11.592 11.592 0 0 0-11.592-11.563z" fill="#000000" stroke-width=".5796" style="isolation:isolate"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.3 KiB

+3 -7
View File
@@ -1,4 +1,4 @@
# Open-Chat-GPT REST Backend
# Open-Assistant REST Backend
## REST Server Configuration
@@ -8,14 +8,10 @@ Example contents of a `.env` file for the backend:
```
DATABASE_URI="postgresql://<username>:<password>@<host>/<database_name>"
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.ocgpt.laion.ai", "https://stag.ocgpt.laion.ai", "https://ocgpt.laion.ai"]
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"]
```
## Running the REST Server locally for development
First, install the requirements in `requirements.txt`.
Then, run two terminals (note the working directory for each):
- Terminal 1, to go `backend/scripts` and run `docker-compose up`. This will start postgres.
- Terminal 2, to go `backend` and run `scripts/run-local.sh`. This will start the REST server.
Have a look into the main `README.md` file for more information on how to set up the backend for development.
+2 -2
View File
@@ -8,7 +8,7 @@ script_location = %(here)s/alembic
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
@@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# output_encoding = utf-8
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
+3 -1
View File
@@ -3,7 +3,7 @@ from logging.config import fileConfig
import sqlmodel
from alembic import context
from app import models # noqa: F401
from oasst_backend import models # noqa: F401
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
@@ -68,6 +68,8 @@ def run_migrations_online() -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.get_context()._ensure_version_table()
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
context.run_migrations()
@@ -0,0 +1,192 @@
# -*- coding: utf-8 -*-
"""v1 db structure
Revision ID: cd7de470586e
Revises: 23e5fea252dd
Create Date: 2022-12-15 11:15:32.830225
"""
import uuid
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import JSONB, UUID
# revision identifiers, used by Alembic.
revision = "cd7de470586e"
down_revision = "23e5fea252dd"
branch_labels = None
depends_on = None
def upgrade() -> None:
# remove database objects
op.drop_index(op.f("prompt_labeler_id"), table_name="prompt")
op.drop_table("prompt")
op.drop_table("labeler")
op.drop_index(op.f("ix_service_client_api_key"), table_name="service_client")
op.drop_table("service_client")
# wreate new database structure
op.create_table(
"api_client",
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
sa.Column("api_key", sa.String(512), nullable=False),
sa.Column("description", sa.String(256), nullable=False),
sa.Column("admin_email", sa.String(256), nullable=True),
sa.Column("enabled", sa.Boolean, default=True, nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_api_client_api_key"), "api_client", ["api_key"], unique=True)
op.create_table(
"person",
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
sa.Column("username", sa.String(128), nullable=False), # unique in combination with api_client_id
sa.Column("display_name", sa.String(256), nullable=False), # cached last seen display_name
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
)
op.create_index(op.f("ix_person_username"), "person", ["api_client_id", "username"], unique=True)
op.create_table(
"person_stats",
sa.Column("person_id", UUID(as_uuid=True)),
sa.Column("leader_score", sa.Integer, default=0, nullable=False), # determines position on leader board
sa.Column("modified_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("reactions", sa.Integer, default=0, nullable=False), # reactions sent by user
sa.Column("posts", sa.Integer, default=0, nullable=False), # posts sent by user
sa.Column("upvotes", sa.Integer, default=0, nullable=False), # received upvotes (form other users)
sa.Column("downvotes", sa.Integer, default=0, nullable=False), # received downvotes (from other users)
sa.Column("work_reward", sa.Integer, default=0, nullable=False), # reward for workpackage completions
sa.Column("compare_wins", sa.Integer, default=0, nullable=False), # num times user's post won compare tasks
sa.Column("compare_losses", sa.Integer, default=0, nullable=False), # num times users's post lost compare tasks
sa.PrimaryKeyConstraint("person_id"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
)
op.create_table(
"work_package",
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("expiry_date", sa.DateTime(), nullable=True),
sa.Column("person_id", UUID(as_uuid=True), nullable=True),
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
)
op.create_index(op.f("ix_work_package_person_id"), "work_package", ["person_id"], unique=False)
op.create_table(
"post",
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
sa.Column("parent_id", UUID(as_uuid=True), nullable=True), # root posts have NULL parent
sa.Column("thread_id", UUID(as_uuid=True), nullable=False), # id of thread root
sa.Column("workpackage_id", UUID(as_uuid=True), nullable=True), # workpackage id to pass to handler on reply
sa.Column("person_id", UUID(as_uuid=True), nullable=True), # sender (recipients are part of payload)
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
sa.Column("role", sa.String(128), nullable=False), # 'assistant', 'user' or something else
sa.Column("frontend_post_id", sa.String(200), nullable=False), # unique together with api_client_id
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
)
op.create_index(op.f("ix_post_frontend_post_id"), "post", ["api_client_id", "frontend_post_id"], unique=True)
op.create_index(op.f("ix_post_thread_id"), "post", ["thread_id"], unique=False)
op.create_index(op.f("ix_post_workpackage_id"), "post", ["workpackage_id"], unique=False)
op.create_index(op.f("ix_post_person_id"), "post", ["person_id"], unique=False)
op.create_table(
"post_reaction",
sa.Column("post_id", UUID(as_uuid=True), nullable=False),
sa.Column("person_id", UUID(as_uuid=True), nullable=False), # sender (recipients are part of payload)
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("post_id", "person_id"),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
)
def downgrade() -> None:
op.drop_table("post_reaction")
op.drop_index("ix_post_person_id")
op.drop_index("ix_post_workpackage_id")
op.drop_index("ix_post_thread_id")
op.drop_index("ix_post_frontend_post_id")
op.drop_table("post")
op.drop_index("ix_work_package_person_id")
op.drop_table("work_package")
op.drop_table("person_stats")
op.drop_index("ix_person_username")
op.drop_table("person")
op.drop_index("ix_api_client_api_key")
op.drop_table("api_client")
op.create_table(
"service_client",
sa.Column("id", sa.Integer, sa.Identity()),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("service_admin_email", sa.String(128), nullable=True),
sa.Column("api_key", sa.String(300), nullable=False),
sa.Column("can_append", sa.Boolean, nullable=False, server_default="true"),
sa.Column("can_write", sa.Boolean, nullable=False, server_default="false"),
sa.Column("can_delete", sa.Boolean, nullable=False, server_default="false"),
sa.Column("can_read", sa.Boolean, nullable=False, server_default="true"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_service_client_api_key"), "service_client", ["api_key"], unique=True)
op.create_table(
"labeler",
sa.Column("id", sa.Integer, sa.Identity()),
sa.Column("display_name", sa.String(96), nullable=False),
sa.Column("discord_username", sa.String(96), nullable=True),
sa.Column(
"created_date",
sa.DateTime,
nullable=False,
server_default=sa.func.current_timestamp(),
),
sa.Column("is_enabled", sa.Boolean, nullable=False, server_default="true"),
sa.Column("notes", sa.String(10 * 1024), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("discord_username"),
)
op.create_table(
"prompt",
sa.Column("id", sa.Integer, sa.Identity()),
sa.Column("labeler_id", sa.Integer, nullable=False),
sa.Column("prompt", sa.Text, nullable=False),
sa.Column("response", sa.Text, nullable=True),
sa.Column("lang", sa.String(32), nullable=True),
sa.Column(
"created_date",
sa.DateTime(),
nullable=False,
server_default=sa.func.current_timestamp(),
),
sa.ForeignKeyConstraint(
["labeler_id"],
["labeler.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("prompt_labeler_id"), "prompt", ["labeler_id"], unique=False)
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""add auth_method to person
Revision ID: 6368515778c5
Revises: cd7de470586e
Create Date: 2022-12-17 17:57:33.022549
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "6368515778c5"
down_revision = "cd7de470586e"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("person", sa.Column("auth_method", sa.String(length=128), nullable=True))
op.execute("UPDATE person SET auth_method = 'local'")
op.alter_column("person", "auth_method", nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("person", "auth_method")
# ### end Alembic commands ###
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""add_auth_method_to_ix_person_username
Revision ID: 0daec5f8135f
Revises: 6368515778c5
Create Date: 2022-12-22 18:35:59.609013
"""
import sqlalchemy as sa # noqa: F401
from alembic import op
# revision identifiers, used by Alembic.
revision = "0daec5f8135f"
down_revision = "6368515778c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_person_username", table_name="person")
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_person_username", table_name="person")
op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False)
# ### end Alembic commands ###
-50
View File
@@ -1,50 +0,0 @@
# -*- coding: utf-8 -*-
from typing import Generator
from app.database import engine
from app.models import ServiceClient
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN
def get_db() -> Generator:
with Session(engine) as db:
yield db
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
):
if api_key_query:
return api_key_query
else:
return api_key_header
def api_auth(
api_key: APIKey,
db: Session,
create: bool = False,
read: bool = True,
update: bool = False,
delete: bool = False,
) -> ServiceClient:
if api_key is not None:
api_client = db.query(ServiceClient).filter(ServiceClient.api_key == api_key).first()
if api_client is not None:
if (
(create is False or api_client.can_append)
and (read is False or api_client.can_read)
and (update is False or api_client.can_write)
and (delete is False or api_client.can_delete)
):
return api_client
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
-7
View File
@@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-
from app.api.v1 import labelers, prompts
from fastapi import APIRouter
api_router = APIRouter()
api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"])
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
-114
View File
@@ -1,114 +0,0 @@
# -*- coding: utf-8 -*-
from typing import Any, List
from app import crud, schemas
from app.api import deps
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
router = APIRouter()
@router.get("/", response_model=List[schemas.Labeler])
def read_labelers(
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
begin_id: int = 0,
limit: int = 100,
) -> Any:
"""
Retrieve labelers.
"""
deps.api_auth(api_key, db, read=True)
if limit > 10000:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
labelers = crud.labeler.get_multi(db, begin_id=begin_id, limit=limit)
return labelers
@router.post("/", response_model=schemas.Labeler)
def create_labeler(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
item_in: schemas.LabelerCreate,
) -> Any:
"""
Create new labeler.
"""
deps.api_auth(api_key, db, create=True)
item = crud.labeler.create(db=db, obj_in=item_in)
return item
@router.put("/{id}", response_model=schemas.Labeler)
def update_labeler(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
id: int,
item_in: schemas.LabelerUpdate,
) -> Any:
"""
Update a labeler.
"""
deps.api_auth(api_key, db, update=True, read=True)
item = crud.labeler.get(db=db, id=id)
if not item:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
item = crud.labeler.update(db=db, db_obj=item, obj_in=item_in)
return item
@router.get("/by-username", response_model=schemas.Labeler)
def read_labeler_by_username(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
discord_username: str,
) -> Any:
"""
Get labeler by ID.
"""
deps.api_auth(api_key, db, read=True)
item = crud.labeler.get_by_discord_username(db=db, discord_username=discord_username)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return item
@router.get("/{id}", response_model=schemas.Labeler)
def read_labeler(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
id: int,
) -> Any:
"""
Get labeler by ID.
"""
deps.api_auth(api_key, db, read=True)
item = crud.labeler.get(db=db, id=id)
if not item:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
return item
@router.delete("/{id}", response_model=schemas.Labeler)
def delete_labeler(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
id: int,
) -> Any:
"""
Delete a labeler.
"""
deps.api_auth(api_key, db, delete=True)
labeler = crud.labeler.get(db=db, id=id)
if not labeler:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
labeler = crud.labeler.remove(db=db, id=id)
return labeler
-91
View File
@@ -1,91 +0,0 @@
# -*- coding: utf-8 -*-
from typing import Any, List
from app import crud, schemas
from app.api import deps
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND
router = APIRouter()
@router.get("/", response_model=List[schemas.Prompt])
def read_prompts(
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
begin_id: int = 0,
limit: int = 1000,
) -> Any:
"""
Retrieve prompts.
"""
deps.api_auth(api_key, db, read=True)
if limit > 10000:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
return crud.prompt.get_multi(db, begin_id=begin_id, limit=limit)
@router.post("/", response_model=schemas.Prompt)
def create_prompt(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
item_in: schemas.PromptCreate,
) -> Any:
"""
Create new prompt.
"""
deps.api_auth(api_key, db, create=True)
if item_in.labeler_id is None:
if item_in.discord_username is None:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
labeler = crud.labeler.get_by_discord_username(db=db, discord_username=item_in.discord_username)
else:
labeler = crud.labeler.get(db=db, id=item_in.labeler_id)
if labeler is None:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Invalid labeler user name")
if not labeler.is_enabled:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Labeler disabled")
item_in.labeler_id = labeler.id
item_in.discord_username = None
item = crud.prompt.create(db=db, obj_in=item_in)
return item
@router.get("/{id}", response_model=schemas.Prompt)
def read_prompt(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
id: int,
) -> Any:
"""
Get prompt by ID.
"""
deps.api_auth(api_key, db, read=True)
item = crud.prompt.get(db=db, id=id)
if not item:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
return item
@router.delete("/{id}", response_model=schemas.Prompt)
def delete_prompt(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
id: int,
) -> Any:
"""
Delete a prompt.
"""
deps.api_auth(api_key, db, delete=True)
item = crud.prompt.get(db=db, id=id)
if not item:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
item = crud.prompt.remove(db=db, id=id)
return item
-25
View File
@@ -1,25 +0,0 @@
# -*- coding: utf-8 -*-
# touch
from typing import List, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
class Settings(BaseSettings):
PROJECT_NAME: str = "open-chatGPT backend"
API_V1_STR: str = "/api/v1"
DATABASE_URI: Optional[PostgresDsn] = None
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
UPDATE_ALEMBIC: bool = True
@validator("BACKEND_CORS_ORIGINS", pre=True)
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
elif isinstance(v, (list, str)):
return v
raise ValueError(v)
settings = Settings(_env_file=".env")
-5
View File
@@ -1,5 +0,0 @@
# -*- coding: utf-8 -*-
from .crud_labeler import labeler
from .crud_prompt import prompt
__all__ = ["labeler", "prompt"]
-15
View File
@@ -1,15 +0,0 @@
# -*- coding: utf-8 -*-
from typing import Optional
from app.crud.base import CRUDBase
from app.models.labeler import Labeler
from app.schemas.labeler import LabelerCreate, LabelerUpdate
from sqlmodel import Session
class CRUDLabeler(CRUDBase[Labeler, LabelerCreate, LabelerUpdate]):
def get_by_discord_username(self, db: Session, discord_username: str) -> Optional[Labeler]:
return db.query(Labeler).filter(Labeler.discord_username == discord_username).first()
labeler = CRUDLabeler(Labeler)
-11
View File
@@ -1,11 +0,0 @@
# -*- coding: utf-8 -*-
from app.crud.base import CRUDBase
from app.models.prompt import Prompt
from app.schemas.prompt import PromptCreate
class CRUDPrompt(CRUDBase[Prompt, PromptCreate, None]):
pass
prompt = CRUDPrompt(Prompt)
-6
View File
@@ -1,6 +0,0 @@
# -*- coding: utf-8 -*-
from .labeler import Labeler
from .prompt import Prompt
from .service_client import ServiceClient
__all__ = ["Labeler", "Prompt", "ServiceClient"]
-19
View File
@@ -1,19 +0,0 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlmodel import Field, SQLModel
class Labeler(SQLModel, table=True):
__tablename__ = "labeler"
id: Optional[int] = Field(default=None, primary_key=True)
display_name: str
discord_username: str
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
nullable=False,
)
is_enabled: bool
notes: str
-19
View File
@@ -1,19 +0,0 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlmodel import Field, SQLModel
class Prompt(SQLModel, table=True):
__tablename__ = "prompt"
id: Optional[int] = Field(default=None, primary_key=True)
labeler_id: Optional[int] = Field(default=None, foreign_key="labeler.id")
prompt: str
response: Optional[str]
lang: Optional[str]
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
nullable=False,
)
-17
View File
@@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
from typing import Optional
from sqlmodel import Field, SQLModel
class ServiceClient(SQLModel, table=True):
__tablename__ = "service_client"
id: Optional[int] = Field(default=None, primary_key=True)
name: str
api_key: str
service_admin_email: Optional[str] = None
api_key: str
can_append: bool = True
can_write: bool = False
can_delete: bool = False
can_read: bool = True
-5
View File
@@ -1,5 +0,0 @@
# -*- coding: utf-8 -*-
from .labeler import Labeler, LabelerCreate, LabelerUpdate
from .prompt import Prompt, PromptCreate
__all__ = ["Labeler", "LabelerCreate", "LabelerUpdate", "Prompt", "PromptCreate"]
-28
View File
@@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
class Labeler(BaseModel):
id: int
discord_username: str
display_name: str
created_date: datetime
is_enabled: str
notes: Optional[str]
class LabelerCreate(BaseModel):
discord_username: str
display_name: Optional[str]
is_enabled: Optional[bool] = True
notes: Optional[str] = None
class LabelerUpdate(BaseModel):
discord_username: Optional[str] = None
display_name: Optional[str] = None
enabled: Optional[bool] = None
notes: Optional[str] = None
-22
View File
@@ -1,22 +0,0 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
class Prompt(BaseModel):
id: int
labeler_id: int
prompt: str
response: Optional[str]
lang: Optional[str]
created_date: datetime
class PromptCreate(BaseModel):
labeler_id: Optional[int] = None
discord_username: Optional[str] = None
prompt: str
response: Optional[str] = None
lang: Optional[str] = None
-14
View File
@@ -1,14 +0,0 @@
FROM python:3.9
WORKDIR /code
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
COPY ./app /code/app
COPY ./app /app
ENV PYTHONPATH=/app
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
-3
View File
@@ -1,3 +0,0 @@
FROM postgres:15
COPY ./scripts/create-db.sh /docker-entrypoint-initdb.d/
+3 -3
View File
@@ -4,9 +4,9 @@ from pathlib import Path
import alembic.command
import alembic.config
import fastapi
from app.api.v1.api import api_router
from app.config import settings
from loguru import logger
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from starlette.middleware.cors import CORSMiddleware
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
@@ -27,7 +27,7 @@ if settings.UPDATE_ALEMBIC:
def alembic_upgrade():
logger.info("Attempting to upgrade alembic on startup")
try:
alembic_ini_path = Path(__file__).parent.parent / "alembic.ini"
alembic_ini_path = Path(__file__).parent / "alembic.ini"
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
alembic_cfg.set_main_option("sqlalchemy.url", settings.DATABASE_URI)
alembic.command.upgrade(alembic_cfg, "head")
+57
View File
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.models import ApiClient
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN
def get_db() -> Generator:
with Session(engine) as db:
yield db
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
):
if api_key_query:
return api_key_query
else:
return api_key_header
def api_auth(
api_key: APIKey,
db: Session,
) -> ApiClient:
if api_key is not None:
if settings.ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
+6
View File
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+259
View File
@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
import random
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.models.db_payload import TaskPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
match request.type:
case protocol_schema.TaskRequestType.random:
logger.info("Frontend requested a random task.")
while request.type == protocol_schema.TaskRequestType.random:
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
return generate_task(request)
case protocol_schema.TaskRequestType.summarize_story:
logger.info("Generating a SummarizeStoryTask.")
task = protocol_schema.SummarizeStoryTask(
story="This is a story. A very long story. So long, it needs to be summarized.",
)
case protocol_schema.TaskRequestType.rate_summary:
logger.info("Generating a RateSummaryTask.")
task = protocol_schema.RateSummaryTask(
full_text="This is a story. A very long story. So long, it needs to be summarized.",
summary="This is a summary.",
scale=protocol_schema.RatingScale(min=1, max=5),
)
case protocol_schema.TaskRequestType.initial_prompt:
logger.info("Generating an InitialPromptTask.")
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
task = protocol_schema.UserReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
)
)
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
task = protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, write me an English essay about water.",
is_assistant=False,
),
],
)
)
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
task = protocol_schema.RankInitialPromptsTask(
prompts=[
"Please write a story about a time you were happy.",
"Please write a story about a time you were sad.",
]
)
case protocol_schema.TaskRequestType.rank_user_replies:
logger.info("Generating a RankUserRepliesTask.")
task = protocol_schema.RankUserRepliesTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
protocol_schema.ConversationMessage(
text="I'm not sure I understood correctly, could you rephrase that?",
is_assistant=True,
),
],
),
replies=[
"Oh come oooooon!",
"What are the news?",
],
)
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(
messages=[
protocol_schema.ConversationMessage(
text="Hey, assistant, what's going on in the world?",
is_assistant=False,
),
],
),
replies=[
"I'm not sure I understood correctly, could you rephrase that?",
"The world is fine. All good.",
"Crap is hitting the fan. Start farming.",
],
)
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid request type.",
)
logger.info(f"Generated {task=}.")
return task
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
def request_task(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: protocol_schema.TaskRequest,
) -> Any:
"""
Create new task.
"""
api_client = deps.api_auth(api_key, db)
try:
task = generate_task(request)
pr = PromptRepository(db, api_client, request.user)
pr.store_task(task)
except Exception:
logger.exception("Failed to generate task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
return task
@router.post("/{task_id}/ack")
def acknowledge_task(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
ack_request: protocol_schema.TaskAck,
) -> Any:
"""
The frontend acknowledges a task.
"""
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=None)
# here we store the post id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
except Exception:
logger.exception("Failed to acknowledge task.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
return {}
@router.post("/{task_id}/nack")
def acknowledge_task_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
nack_request: protocol_schema.TaskNAck,
) -> Any:
"""
The frontend reports failure to implement a task.
"""
deps.api_auth(api_key, db)
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
# here we would store the post id in the database for the task
return {}
@router.post("/interaction")
def post_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
interaction: protocol_schema.AnyInteraction,
) -> Any:
"""
The frontend reports an interaction.
"""
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToPost:
logger.info(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
)
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
work_payload: TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# here we store the text reply in the database
# ToDo: role user or agent?
pr.store_text_reply(interaction, role="unknown")
return protocol_schema.TaskDone()
case protocol_schema.PostRating:
logger.info(
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
)
# here we store the rating in the database
pr.store_rating(interaction)
return protocol_schema.TaskDone()
case protocol_schema.PostRanking:
logger.info(
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
)
# TODO: check if the ranking is valid
pr.store_ranking(interaction)
# here we would store the ranking in the database
return protocol_schema.TaskDone()
case _:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail="Invalid response type.",
)
except Exception:
logger.exception("Interaction request failed.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
+45
View File
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
class Settings(BaseSettings):
PROJECT_NAME: str = "open-assistant backend"
API_V1_STR: str = "/api/v1"
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
POSTGRES_USER: str = "postgres"
POSTGRES_PASSWORD: str = "postgres"
POSTGRES_DB: str = "postgres"
DATABASE_URI: Optional[PostgresDsn] = None
ALLOW_ANY_API_KEY: bool = False
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
if isinstance(v, str):
return v
return PostgresDsn.build(
scheme="postgresql",
user=values.get("POSTGRES_USER"),
password=values.get("POSTGRES_PASSWORD"),
host=values.get("POSTGRES_HOST"),
port=values.get("POSTGRES_PORT"),
path=f"/{values.get('POSTGRES_DB') or ''}",
)
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
UPDATE_ALEMBIC: bool = True
@validator("BACKEND_CORS_ORIGINS", pre=True)
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
elif isinstance(v, (list, str)):
return v
raise ValueError(v)
settings = Settings(_env_file=".env")
+2
View File
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
__all__ = []
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from app.config import settings
from oasst_backend.config import settings
from sqlmodel import create_engine
if settings.DATABASE_URI is None:
+16
View File
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
from .api_client import ApiClient
from .person import Person
from .person_stats import PersonStats
from .post import Post
from .post_reaction import PostReaction
from .work_package import WorkPackage
__all__ = [
"ApiClient",
"Person",
"PersonStats",
"Post",
"PostReaction",
"WorkPackage",
]
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class ApiClient(SQLModel, table=True):
__tablename__ = "api_client"
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
api_key: str = Field(max_length=512, index=True, unique=True)
description: str = Field(max_length=256)
admin_email: Optional[str] = Field(max_length=256, nullable=True)
enabled: bool = Field(default=True)
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
from typing import Literal
from oasst_backend.models.payload_column_type import payload_type
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@payload_type
class TaskPayload(BaseModel):
type: str
@payload_type
class SummarizationStoryPayload(TaskPayload):
type: Literal["summarize_story"] = "summarize_story"
story: str
@payload_type
class RateSummaryPayload(TaskPayload):
type: Literal["rate_summary"] = "rate_summary"
full_text: str
summary: str
scale: protocol_schema.RatingScale
@payload_type
class InitialPromptPayload(TaskPayload):
type: Literal["initial_prompt"] = "initial_prompt"
hint: str
@payload_type
class UserReplyPayload(TaskPayload):
type: Literal["user_reply"] = "user_reply"
conversation: protocol_schema.Conversation
hint: str | None
@payload_type
class AssistantReplyPayload(TaskPayload):
type: Literal["assistant_reply"] = "assistant_reply"
conversation: protocol_schema.Conversation
@payload_type
class PostPayload(BaseModel):
text: str
@payload_type
class ReactionPayload(BaseModel):
type: str
@payload_type
class RatingReactionPayload(ReactionPayload):
type: Literal["post_rating"] = "post_rating"
rating: str
@payload_type
class RankingReactionPayload(ReactionPayload):
type: Literal["post_ranking"] = "post_ranking"
ranking: list[int]
@payload_type
class RankConversationRepliesPayload(TaskPayload):
conversation: protocol_schema.Conversation # the conversation so far
replies: list[str]
@payload_type
class RankInitialPromptsPayload(TaskPayload):
"""A task to rank a set of initial prompts."""
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
prompts: list[str]
@payload_type
class RankUserRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of user replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
@payload_type
class RankAssistantRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of assistant replies to a conversation."""
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
import json
from typing import Any, Generic, Type, TypeVar
import sqlalchemy.dialects.postgresql as pg
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, parse_obj_as, validator
from pydantic.main import ModelMetaclass
from sqlalchemy.types import TypeDecorator
payload_type_registry = {}
P = TypeVar("P", bound=BaseModel)
def payload_type(cls: Type[P]) -> Type[P]:
payload_type_registry[cls.__name__] = cls
return cls
class PayloadContainer(BaseModel):
payload_type: str = ""
payload: BaseModel = None
def __init__(self, **v):
p = v["payload"]
if isinstance(p, dict):
t = v["payload_type"]
if t not in payload_type_registry:
raise RuntimeError(f"Payload type '{t}' not registered")
cls = payload_type_registry[t]
v["payload"] = cls(**p)
super().__init__(**v)
@validator("payload", pre=True)
def check_payload(cls, v: BaseModel, values: dict[str, Any]) -> BaseModel:
values["payload_type"] = type(v).__name__
return v
class Config:
orm_mode = True
T = TypeVar("T")
def payload_column_type(pydantic_type):
class PayloadJSONBType(TypeDecorator, Generic[T]):
impl = pg.JSONB()
def __init__(
self,
json_encoder=json,
):
self.json_encoder = json_encoder
super(PayloadJSONBType, self).__init__()
# serialize
def bind_processor(self, dialect):
impl_processor = self.impl.bind_processor(dialect)
dumps = self.json_encoder.dumps
def process(value: T):
if value is not None:
if isinstance(pydantic_type, ModelMetaclass):
# This allows to assign non-InDB models and if they're
# compatible, they're directly parsed into the InDB
# representation, thus hiding the implementation in the
# background. However, the InDB model will still be returned
value_to_dump = pydantic_type.from_orm(value)
else:
value_to_dump = value
value = jsonable_encoder(value_to_dump)
if impl_processor:
return impl_processor(value)
else:
return dumps(jsonable_encoder(value_to_dump))
return process
# deserialize
def result_processor(self, dialect, coltype) -> T:
impl_processor = self.impl.result_processor(dialect, coltype)
def process(value):
if impl_processor:
value = impl_processor(value)
if value is None:
return None
# Explicitly use the generic directly, not type(T)
full_obj = parse_obj_as(pydantic_type, value)
return full_obj
return process
def compare_values(self, x, y):
return x == y
return PayloadJSONBType
+26
View File
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class Person(SQLModel, table=True):
__tablename__ = "person"
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
username: str = Field(nullable=False, max_length=128)
auth_method: str = Field(nullable=False, max_length=128, default="local")
display_name: str = Field(nullable=False, max_length=256)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
api_client_id: UUID = Field(foreign_key="api_client.id")
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class PersonStats(SQLModel, table=True):
__tablename__ = "person_stats"
person_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
)
leader_score: int = 0
modified_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
reactions: int = 0 # reactions sent by user
posts: int = 0 # posts sent by user
upvotes: int = 0 # received upvotes (form other users)
downvotes: int = 0 # received downvotes (from other users)
work_reward: int = 0 # reward for workpackage completions
compare_wins: int = 0 # num times user's post won compare tasks
compare_losses: int = 0 # num times users's post lost compare tasks
+33
View File
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class Post(SQLModel, table=True):
__tablename__ = "post"
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
parent_id: UUID = Field(nullable=True)
thread_id: UUID = Field(nullable=False, index=True)
workpackage_id: UUID = Field(nullable=True, index=True)
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
role: str = Field(nullable=False, max_length=128)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_post_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class PostReaction(SQLModel, table=True):
__tablename__ = "post_reaction"
post_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=False, primary_key=True)
)
person_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class WorkPackage(SQLModel, table=True):
__tablename__ = "work_package"
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
+316
View File
@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
self.db = db
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
return None
person: Person = (
self.db.query(Person)
.filter(
Person.api_client_id == self.api_client.id,
Person.username == user.id,
Person.auth_method == user.auth_method,
)
.first()
)
if person is None:
# user is unknown, create new record
person = Person(
username=user.id,
display_name=user.display_name,
api_client_id=self.api_client.id,
auth_method=user.auth_method,
)
self.db.add(person)
self.db.commit()
self.db.refresh(person)
elif user.display_name and user.display_name != person.display_name:
# we found the user but the display name changed
person.display_name = user.display_name
self.db.add(person)
self.db.commit()
return person
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise TypeError(f"post_id must be string, not {type(post_id)}")
if not post_id:
raise ValueError("post_id must not be empty")
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise KeyError(f"WorkPackage for task {task_id} not found")
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
raise RuntimeError("WorkPackage already expired.")
# ToDo: check race-condition, transaction
# check if task thread exits
thread_root = (
self.db.query(Post)
.filter(
Post.workpackage_id == work_pack.id,
Post.frontend_post_id == post_id,
Post.parent_id is None,
Post.api_client_id == self.api_client.id,
)
.one_or_none()
)
if thread_root is None:
thread_id = uuid4()
thread_root = self.insert_post(
post_id=thread_id,
thread_id=thread_id,
frontend_post_id=post_id,
parent_id=None,
role="system",
workpackage_id=work_pack.id,
payload=None,
payload_type="bind",
)
return thread_root
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
self.validate_post_id(frontend_post_id)
post: Post = (
self.db.query(Post)
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
.one_or_none()
)
if fail_if_missing and post is None:
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
return post
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
self.validate_post_id(post_id)
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
return work_pack
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
self.validate_post_id(reply.post_id)
self.validate_post_id(reply.user_post_id)
# find post with post-id
parent_post: Post = (
self.db.query(Post)
.filter(
Post.api_client_id == self.api_client.id,
Post.frontend_post_id == reply.post_id,
# Post.person_id == self.person_id
)
.one_or_none()
)
if parent_post is None:
raise KeyError(f"Post for post_id {reply.post_id} not found.")
# create reply post
user_post_id = uuid4()
user_post = self.insert_post(
post_id=user_post_id,
frontend_post_id=reply.user_post_id,
parent_id=parent_post.id,
thread_id=parent_post.thread_id,
workpackage_id=parent_post.workpackage_id,
role=role,
payload=db_payload.PostPayload(text=reply.text),
)
return user_post
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
return reaction
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True)
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
)
match type(work_payload):
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
# validate ranking
num_replies = len(work_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
return reaction
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
return reaction
case _:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
)
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.UserReplyTask:
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
case protocol_schema.RankUserRepliesTask:
payload = db_payload.RankUserRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
)
case _:
raise ValueError(f"Invalid task type: {type(task)=}")
wp = self.insert_work_package(payload=payload, id=task.id)
assert wp.id == task.id
return wp
def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
id=id,
person_id=self.person_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
)
self.db.add(wp)
self.db.commit()
self.db.refresh(wp)
return wp
def insert_post(
self,
*,
post_id: UUID,
frontend_post_id: str,
parent_id: UUID,
thread_id: UUID,
workpackage_id: UUID,
role: str,
payload: db_payload.PostPayload,
payload_type: str = None,
) -> Post:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
post = Post(
id=post_id,
parent_id=parent_id,
thread_id=thread_id,
workpackage_id=workpackage_id,
person_id=self.person_id,
role=role,
frontend_post_id=frontend_post_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
)
self.db.add(post)
self.db.commit()
self.db.refresh(post)
return post
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
if self.person_id is None:
raise ValueError("User required")
container = PayloadContainer(payload=payload)
reaction = PostReaction(
post_id=post_id,
person_id=self.person_id,
payload=container,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
)
self.db.add(reaction)
self.db.commit()
self.db.refresh(reaction)
return reaction
-6
View File
@@ -1,6 +0,0 @@
#!/bin/bash
set -e
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
CREATE DATABASE ocgpt_backend;
EOSQL
-5
View File
@@ -1,5 +0,0 @@
#!/usr/bin/env bash
export DATABASE_URI=postgresql://postgres:postgres@localhost:5432/postgres
uvicorn app.main:app --reload
-13
View File
@@ -1,13 +0,0 @@
install:
python -m pip install -U pip
python -m pip install -e .
lint: ## [Local development] Run pylint and black
python -m pylint app
python -m black --check -l 120 app
black: ## [Local development] Auto-format python code using black
python -m black -l 120 .
run:
python -m bot
-14
View File
@@ -1,14 +0,0 @@
# open-chat-gpt
This is the github repo for the open-chat-gpt project.
We are currently building a discord bot in order to make everyone contribute with great prompts and answers.
Join us!
https://discord.gg/ZUfPw6jP
## Project description
We are calling the community for help to collect ChatGPT-like Instruction-Fulfillment datasamples via Discord. People can post Instructions they think would make sense for ChatGPT-like systems & also provide a good reference answer for it.
## Todo
Figure out ouath flow for the app to work inside the open-chat-gpt testing channel here. https://discord.gg/JJSKtRhv
-207
View File
@@ -1,207 +0,0 @@
# -*- coding: utf-8 -*-
import json
import os
import discord
import requests
from discord import app_commands
from dotenv import load_dotenv
from loguru import logger
bot_url = "https://discord.com/api/oauth2/authorize?client_id=1051614245940375683&permissions=8&scope=bot"
# Load up all the important environment variables.
load_dotenv()
# For authentication.
TOKEN = os.getenv("DISCORD_TOKEN")
# For Backends.
API_SERVER_URL = os.getenv("API_SERVER_URL")
API_SERVER_KEY = os.getenv("API_SERVER_KEY")
labelers_url = f"{API_SERVER_URL}/api/v1/labelers/"
prompts_url = f"{API_SERVER_URL}/api/v1/prompts/"
headers = {"X-API-Key": API_SERVER_KEY}
# For testing only.
TEST_GUILD = os.getenv("TEST_GUILD")
# Initiate the client and command tree to create slash commands.
class OpenChatGPTClient(discord.Client):
def __init__(self, *, intents: discord.Intents):
super().__init__(intents=intents)
self.tree = app_commands.CommandTree(self)
async def setup_hook(self):
if TEST_GUILD:
# When testing the bot it's handy to run in a single server (called a
# Guide in the API). This is relatively fast.
guild = discord.Object(id=TEST_GUILD)
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)
else:
# This can take up to an hour for the commands to be registered.
await self.tree.sync()
logger.debug("Ready!")
# List the set of intents needed for commands to operate properly.
intents = discord.Intents.default()
intents.message_content = True
client = OpenChatGPTClient(intents=intents)
class LikeButton(discord.ui.Button):
def __init__(self, label, channel, username, prompt):
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👍")
self.channel = channel
self.username = username
self.prompt = prompt
async def callback(self, interaction):
# interaction holds the interaction object
# await interaction.response.defer()
await interaction.response.send_message("Thanks for your feedback. You liked this 👍 ")
class NeutralButton(discord.ui.Button):
def __init__(self, label, channel, username, prompt):
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="😐")
self.channel = channel
self.username = username
self.prompt = prompt
async def callback(self, interaction):
# interaction holds the interaction object
# await interaction.response.defer()
await interaction.response.send_message("Thanks for your feedback. You thought this was neutral 😐 ")
class DislikeButton(discord.ui.Button):
def __init__(self, label, channel, username, prompt):
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👎")
self.channel = channel
self.username = username
self.prompt = prompt
async def callback(self, interaction):
# interaction holds the interaction object
# await interaction.response.defer()
# send the feedback to the backend #
await interaction.response.send_message("Thanks for your feedback. You disliked this 👎 ")
@client.tree.command()
async def register(interaction: discord.Interaction):
"""Registers the user for submissions."""
labeler = {
"discord_username": f"{interaction.user.id}",
"display_name": interaction.user.name,
"is_enabled": True,
}
response = requests.post(labelers_url, headers=headers, json=labeler)
if response.status_code == 200:
await interaction.response.send_message(f"Added you {interaction.user.name}")
else:
logger.debug(response)
await interaction.response.send_message("Failed to add you")
@client.tree.command()
async def list_participants(interaction: discord.Interaction):
"""Reports the set of registered participants."""
response = requests.get(labelers_url, headers=headers)
if response.status_code == 200:
names = ",".join([labeler["display_name"] for labeler in response.json()])
await interaction.response.send_message(f"Found these users: {names}")
else:
await interaction.response.send_message("Failed to fetch participants")
async def send_prompt_with_response_and_button(channel, username, prompt, response):
await channel.send(f"What do you think about the following interaction: \nprompt: {prompt} \nresponse: {response}")
# await channel.send(f'Please click on the button that best describes your reaction to the response:')
# add buttons
view = discord.ui.View()
like = LikeButton(label="Like", channel=channel, username=username, prompt=prompt)
neutral = NeutralButton(label="Neutral", channel=channel, username=username, prompt=prompt)
dislike = DislikeButton(label="Dislike", channel=channel, username=username, prompt=prompt)
view.add_item(item=like)
view.add_item(item=neutral)
view.add_item(item=dislike)
await channel.send(view=view)
@client.tree.command()
async def review_prompts(interaction: discord.Interaction, number_of_prompts: int):
# get the prompt from the db
url = f"{prompts_url}?begin_id=0&limit={number_of_prompts}"
response = requests.get(url, headers=headers)
if response.status_code == 200:
prompts = response.json()
logger.debug("the responses are:", prompts)
for prompt in prompts:
await send_prompt_with_response_and_button(
interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
)
else:
await interaction.response.send_message("Failed to get prompts for review")
@client.tree.command()
async def add_prompt(interaction: discord.Interaction, prompt: str, response: str, language: str = "en"):
"""Uploads a single prompt to the server."""
prompt = {
"discord_username": f"{interaction.user.id}",
"labeler_id": 5,
"prompt": prompt,
"response": response,
"lang": language,
}
response = requests.post(prompts_url, headers=headers, json=prompt)
if response.status_code == 200:
await send_prompt_with_response_and_button(
interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
)
# send the prompt back with buttons for the user to click on
# await interaction.response.send_message("Added your prompt")
else:
await interaction.response.send_message("Failed to add the prompt")
@client.tree.command()
async def add_prompts_set(interaction: discord.Interaction, prompts: discord.Attachment):
"""Uploads a batch of prompts to the server."""
# Loading a bunch of prompts from a file can take a while. So first defer
# the response to ensure we're able to later tell the user what happened.
await interaction.response.defer(ephemeral=True)
# Read the prompts and load them one by one.
# TODO: Upload a batch when the API supports it.
# TODO: Handle incorrect file types and parsing errors.
prompts_raw = await prompts.read()
prompts_loaded = json.loads(prompts_raw)
count = 0
for entry in prompts_loaded:
for response in entry["responses"]:
prompt = {
"discord_username": f"{interaction.user.id}",
"labeler_id": 5,
"prompt": entry["prompt"],
"response": response,
"lang": "en",
}
response = requests.post(prompts_url, headers=headers, json=prompt)
if response.status_code != 200:
await interaction.followup.send("Failed to upload")
return
count += 1
await interaction.followup.send(f"Loaded up {count} prompts")
client.run(TOKEN)
-2
View File
@@ -1,2 +0,0 @@
discord.py==2.1.0
python-dotenv==0.21.0
-29
View File
@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
from setuptools import find_packages, setup
if __name__ == "__main__":
import os
def _read_reqs(relpath):
fullpath = os.path.join(os.path.dirname(__file__), relpath)
with open(fullpath) as f:
return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]
REQUIREMENTS = _read_reqs("requirements.txt")
setup(
name="open-chat-gpt",
packages=find_packages(),
version="0.0.1",
license="Apache 2.0",
description="A Discord Bot for collecting and ranking prompts to train an Open ChatGPT",
keywords=["machine learning", "natural language processing", "discord"],
install_requires=REQUIREMENTS,
classifiers=[
"Development Status :: Alpha",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: Apache License",
"Programming Language :: Python :: 3.6",
],
)
-10
View File
@@ -1,10 +0,0 @@
[
{
"prompt": "tell me the name of two dogs",
"responses": ["Charles", "bobby"]
},
{
"prompt": "Name one type of cheese made in france",
"responses": ["Munster", "Gouda"]
}
]
+20
View File
@@ -0,0 +1,20 @@
# Open-Assistant Data Collection Discord Bot
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large langugae model. You and other people can teach the bot how to respond to user requests by demonstration and by garding and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
## Invite official bot
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
## Bot token for development
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
The simplest way to configure the token is via an `.env` file:
```
BOT_TOKEN=XYZABC123...
```
+17
View File
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
from bot import OpenAssistantBot
from bot_settings import settings
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
if __name__ == "__main__":
bot = OpenAssistantBot(
settings.BOT_TOKEN,
bot_channel_name=settings.BOT_CHANNEL_NAME,
backend_url=settings.BACKEND_URL,
api_key=settings.API_KEY,
owner_id=settings.OWNER_ID,
template_dir=settings.TEMPLATE_DIR,
debug=settings.DEBUG,
)
bot.run()
+74
View File
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import enum
from typing import Optional, Type
import requests
from oasst_shared.schemas import protocol as protocol_schema
class TaskType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
class ApiClient:
def __init__(self, backend_url: str, api_key: str):
self.backend_url = backend_url
self.api_key = api_key
task_models_map: dict[str, Type[protocol_schema.Task]] = {
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.user_reply: protocol_schema.UserReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
self.task_models_map = task_models_map
def post(self, path: str, json: dict) -> dict:
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
response.raise_for_status()
return response.json()
def _parse_task(self, data: dict) -> protocol_schema.Task:
if not isinstance(data, dict):
raise ValueError("dict expected")
task_type = data.get("type")
if task_type not in self.task_models_map:
raise RuntimeError(f"Unsupported task type: {task_type}")
return self.task_models_map[task_type].parse_obj(data)
def fetch_task(
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
) -> protocol_schema.Task:
req = protocol_schema.TaskRequest(type=task_type, user=user)
data = self.post("/api/v1/tasks/", req.dict())
return self._parse_task(data)
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
def ack_task(self, task_id: str, post_id: str) -> None:
req = protocol_schema.TaskAck(post_id=post_id)
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
def nack_task(self, task_id: str, reason: str) -> None:
req = protocol_schema.TaskNAck(reason=reason)
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
data = self.post("/api/v1/tasks/interaction", interaction.dict())
return self._parse_task(data)
+283
View File
@@ -0,0 +1,283 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
bot_channel_name: str,
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
if isinstance(owner_id, str):
owner_id = int(owner_id)
self.owner_id = owner_id
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
if message.author != client.user:
await self.handle_message(message)
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old theards.")
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
handler = task_handlers.UserReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"strarting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "faled")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
logger.info(f"sync tree command received: {command}")
if command == "sync.copy_global":
await self.tree.copy_global_to(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.clear_guild":
self.tree.clear_commands(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.guild":
synced = await self.tree.sync(guild=message.guild)
else:
synced = await self.tree.sync()
logger.info(f"Synced {len(synced)} commands")
await message.reply(f"Synced {len(synced)} commands")
async def handle_command(self, message: discord.Message, is_owner: bool):
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
if channel.type == discord.ChannelType.text and channel.name == channel_name:
return channel
def run(self):
"""Run bot loop blocking."""
self.client.run(self.bot_token)
+61
View File
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from abc import ABC
from dataclasses import dataclass
from typing import Any
import discord
from api_client import ApiClient
from channel_handlers import ChannelHandlerBase
from loguru import logger
from message_templates import MessageTemplates
@dataclass
class ReplyHandlerInfo:
msg_id: int
handler_task: asyncio.Task
handler: ChannelHandlerBase
class BotBase(ABC):
bot_channel_name: str
debug: bool
backend: ApiClient
client: discord.Client
loop: asyncio.BaseEventLoop
owner_id: int
bot_channel: discord.TextChannel
templates: MessageTemplates
reply_handlers: dict[int, ReplyHandlerInfo]
def __init__(self):
self.reply_handlers = {} # handlers by msg_id
def ensure_bot_channel(self) -> None:
if self.bot_channel is None:
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
async def post(
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
) -> discord.Message:
if channel is None:
self.ensure_bot_channel()
channel = self.bot_channel
return await channel.send(content=content, view=view)
async def post_template(
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
) -> discord.Message:
logger.debug(f"rendering {name}")
text = self.templates.render(name, **kwargs)
return await self.post(text, view=view, channel=channel)
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
if msg_id in self.reply_handlers:
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
task.add_done_callback(lambda t: handler.on_completed())
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
+15
View File
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
from pydantic import AnyHttpUrl, BaseSettings
class BotSettings(BaseSettings):
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
API_KEY: str = "any_key"
BOT_TOKEN: str
BOT_CHANNEL_NAME: str = "bot"
OWNER_ID: int = None
TEMPLATE_DIR: str = "./templates"
DEBUG: bool = True
settings = BotSettings(_env_file=".env")
+88
View File
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
import discord
from loguru import logger
class ChannelExpiredException(Exception):
pass
class ChannelHandlerBase(ABC):
queue: asyncio.Queue
completed: bool = False
expiry_date: datetime
expired: bool = False
def __init__(self, *, expiry_date: datetime = None):
self.expiry_date = expiry_date
self.queue = asyncio.Queue()
async def read(self) -> discord.Message:
"""Call this method to read the next message from the user in the handler method."""
if self.expired:
raise ChannelExpiredException()
msg = await self.queue.get()
if msg is None:
if self.expired:
raise ChannelExpiredException()
else:
raise RuntimeError("Unexpected None message read")
return msg
def on_reply(self, message: discord.Message) -> None:
self.queue.put_nowait(message)
def on_expire(self) -> None:
logger.info("ChannelHandler: on_expire")
self.expired = True
self.queue.put_nowait(None)
def on_completed(self) -> None:
logger.info("ChannelHandler: on_completed")
self.completed = True
def tick(self, now: datetime):
if now > self.expiry_date and not self.expired:
self.on_expire()
@abstractmethod
async def handler_loop(self):
...
async def finalize(self):
pass
class AutoDestructThreadHandler(ChannelHandlerBase):
first_message: discord.Message = None
thread: discord.Thread = None
def __init__(self, *, expiry_date: datetime = None):
super().__init__(expiry_date=expiry_date)
async def read(self) -> discord.Message:
try:
return await super().read()
except ChannelExpiredException:
await self.cleanup()
raise
async def cleanup(self):
logger.debug("AutoDestructThreadHandler.cleanup")
if self.thread:
logger.debug(f"deleting thread: {self.thread.name}")
await self.thread.delete()
self.thread = None
if self.first_message:
logger.debug(f"deleting first_message: {self.first_message.content}")
await self.first_message.delete()
self.first_message = None
async def finalize(self):
await self.cleanup()
return await super().finalize()
+18
View File
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
import jinja2
from loguru import logger
class MessageTemplates:
def __init__(self, template_dir="./templates"):
self.env = jinja2.Environment(
loader=jinja2.FileSystemLoader(template_dir),
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
)
def render(self, template_name, **kwargs):
template = self.env.get_template(template_name)
txt = template.render(kwargs)
logger.debug(txt)
return txt
+7
View File
@@ -0,0 +1,7 @@
discord.py==2.1.0
Jinja2==3.1.2
pydantic==1.9.1
python-dotenv==0.21.0
pytz==2022.7
requests==2.28.1
schedule==1.1.0
+267
View File
@@ -0,0 +1,267 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
import discord
from api_client import ApiClient
from bot_base import BotBase
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
name = discord.ui.TextInput(label="Name")
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
class ChannelTaskBase(AutoDestructThreadHandler):
thread_name: str = "Replies"
expires_after: timedelta = timedelta(minutes=5)
backend: ApiClient
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
try:
self.bot = bot
self.task = task
self.backend = bot.backend
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
except Exception:
logger.exception("start task failed")
await self.cleanup() # try to cleanup messag or thread
raise
bot.register_reply_handler(msg_id=msg.id, handler=self)
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
pass
@abstractmethod
async def send_first_message(self) -> discord.message:
...
def to_api_user(self, user: discord.User) -> protocol_schema.User:
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
async def post_teaser_msg(self, template_name: str):
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
return await self.bot.post_template(
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
)
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
api_response = await self.backend.post_interaction(interaction)
if api_response.type != "task_done":
# multi-step tasks are not supported yet
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
raise RuntimeError("Unexpected response from backend received")
return api_response
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.TextReplyToPost(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
text=user_msg.content,
)
)
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
self.post_text_reply_to_post(user_msg)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_text_reply_to_post()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.PostRanking(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
ranking=ranking,
)
)
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
ranking_str = user_msg.content
ranking = [int(x) - 1 for x in ranking_str.split(",")]
self.post_ranking(user_msg, ranking=ranking)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_ranking()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
class SummarizeStoryHandler(ChannelTaskBase):
task: protocol_schema.SummarizeStoryTask
thread_name: str = "Summaries"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_summarize_story.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class InitialPromptHandler(ChannelTaskBase):
task: protocol_schema.InitialPromptTask
thread_name: str = "Prompts"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_initial_prompt.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class UserReplyHandler(ChannelTaskBase):
task: protocol_schema.UserReplyTask
thread_name: str = "User replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_user_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class AssistantReplyHandler(ChannelTaskBase):
task: protocol_schema.AssistantReplyTask
thread_name: str = "Assistant replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_assistant_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class RankInitialPromptsHandler(ChannelTaskBase):
task: protocol_schema.RankInitialPromptsTask
thread_name: str = "User Responses"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RankConversationsHandler(ChannelTaskBase):
task: protocol_schema.RankConversationRepliesTask
thread_name: str = "Rankings"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RatingButton(discord.ui.Button):
def __init__(self, label, value, response_handler):
super().__init__(label=label, style=discord.ButtonStyle.green)
self.value = value
self.response_handler = response_handler
async def callback(self, interaction):
await self.response_handler(self.value, interaction)
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
view = discord.ui.View()
for i in range(lo, hi + 1):
view.add_item(RatingButton(str(i), i, response_handler))
return view
class RateSummaryHandler(ChannelTaskBase):
task: protocol_schema.RateSummaryTask
thread_name: str = "Ratings"
async def _rating_response_handler(self, score, interaction: discord.Interaction):
logger.info("rating_response_handler", score)
if self.thread:
try:
self.backend.post_interaction(
protocol_schema.PostRating(
post_id=str(self.first_message.id),
user_post_id=str(interaction.id),
user=self.to_api_user(interaction.user),
rating=score,
)
)
await interaction.response.send_message(
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
)
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in _rating_response_handler()")
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rate_summary.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
logger.info(f"on_rate_summary_reply: {msg.content}")
await msg.add_reaction("")
await msg.reply("❌ Text intput not supported.")
+13
View File
@@ -0,0 +1,13 @@
```
________ __
\_____ \ _____ ______ _______/ |_
/ | \\__ \ / ___// ___/\ __\
/ | \/ __ \_\___ \ \___ \ | |
\_______ (____ /____ >____ > |__|
\/ \/ \/ \/
{{bot_name}} {{version}}
git hash: {{git_hash}}
debug_mode: {{debug}}
```
https://github.com/LAION-AI/Open-Assistant
+15
View File
@@ -0,0 +1,15 @@
**Open-Assistant Bot Help**
Available slash-commands:
`/work` Requests a new personalized human feedback task
`/help` Show this message
{% if is_bot_owner %}
Commands for bot owners:
`!sync`
`!sync.guild`
`!sync.copy_global`
`!sync.clear_guild`
{% endif %}
@@ -0,0 +1,12 @@
Act as the assistant and reply to the user.
Here is the conversation so far:
{% for message in task.conversation.messages %}
{% if message.is_assistant %}
:robot: Assistant:
{{ message.text }}
{% else %}
:person_red_hair: User:
**{{ message.text }}**"
{% endif %}
{% endfor %}
:robot: Assistant: { human, pls help me! ... }
@@ -0,0 +1,4 @@
Please provide an initial prompt to the assistant.
{% if task.hint is not none %}
Hint: {{task.hint}}
{% endif %}
@@ -0,0 +1,13 @@
Here is the conversation so far:
{% for message in task.conversation.messages %}{% if message.is_assistant %}
:robot: Assistant:
{{ message.text }}
{% else %}
:person_red_hair: User:
**{{ message.text }}**"
{% endif %}{% endfor %}
Rank the following replies:
{% for reply in task.replies %}
{{loop.index}}: {{reply}}{% endfor %}
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
@@ -0,0 +1,5 @@
Rank the following prompts:
{% for prompt in task.prompts %}
{{loop.index}}: {{prompt}}{% endfor %}
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
@@ -0,0 +1,7 @@
Rate the following summary:
{{task.summary}}
Full text:
{{task.full_text}}
Rating scale: {{task.scale.min}} - {{task.scale.max}}
@@ -0,0 +1,2 @@
Summarize to the following story:
{{task.story}}
+12
View File
@@ -0,0 +1,12 @@
Please provide a reply to the assistant.
Here is the conversation so far:
{% for message in task.conversation.messages %}{% if message.is_assistant %}
:robot: Assistant:
{{ message.text }}
{% else %}
:person_red_hair: User:
**{{ message.text }}**"
{% endif %}{% endfor %}
{% if task.hint %}
Hint: {{ task.hint }}
{% endif %}
@@ -0,0 +1,3 @@
:robot: **Challenge: Assistant Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:microphone2: **Challenge: Initial Prompt**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:bar_chart: **Challenge: Rank Replies**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:bar_chart: **Challenge: Rank Initial Prompts**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:ballot_box: **Challenge: Rate Summary**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:books: **Challenge: Summarize Story**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:person_red_hair: **Challenge: User Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+6
View File
@@ -0,0 +1,6 @@
Hi there,
I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗!
Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world.
Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands.
+52
View File
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import enum
import subprocess
from datetime import datetime
import pytz
def get_git_head_hash():
# get current git hash
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
if x.returncode == 0:
return x.stdout.replace("\n", "")
return None
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
class DiscordTimestampStyle(str, enum.Enum):
"""
Timestamp Styles
t 16:20 Short Time
T 16:20:30 Long Time
d 20/04/2021 Short Date
D 20 April 2021 Long Date
f * 20 April 2021 16:20 Short Date/Time
F Tuesday, 20 April 2021 16:20 Long Date/Time
R 2 months ago Relative Time
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
"""
default = ""
short_time = "t"
long_time = "T"
short_date = "d"
long_date = "D"
short_date_time = "f"
long_date_time = "F"
relative_time = "R"
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
parts = ["<t:", str(int(d.timestamp()))]
if style:
parts.append(":")
parts.append(style)
parts.append(">")
return "".join(parts)
+15
View File
@@ -0,0 +1,15 @@
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10
COPY ./backend/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
ENV PORT 8080
COPY ./oasst-shared /oasst-shared
RUN pip install -e /oasst-shared
COPY ./backend/alembic /app/alembic
COPY ./backend/alembic.ini /app/alembic.ini
COPY ./backend/main.py /app/main.py
COPY ./backend/oasst_backend /app/oasst_backend
+7
View File
@@ -0,0 +1,7 @@
FROM python:3.10-slim-bullseye
RUN mkdir /app
COPY ./discord-bot/requirements.txt /requirements.txt
RUN pip install -r requirements.txt
WORKDIR /app
COPY ./discord-bot /app
CMD ["python", "bot.py"]
+62
View File
@@ -0,0 +1,62 @@
# Install dependencies only when needed
FROM node:16.19 AS deps
# Check https://github.com/nodejs/docker-node/tree/b4117f9333da4138b03a546ec926ef50a31506c3#nodealpine to understand why libc6-compat might be needed.
# RUN apk add --no-cache libc6-compat
WORKDIR /app
# Install dependencies based on the preferred package manager
COPY ./website/package.json ./website/package-lock.json ./
RUN \
if [ -f package-lock.json ]; then npm ci; \
else echo "Lockfile not found." && exit 1; \
fi
# Rebuild the source code only when needed
FROM node:16.19 AS builder
WORKDIR /app
COPY --from=deps /app/node_modules ./node_modules
COPY ./website/ .
# Next.js collects completely anonymous telemetry data about general usage.
# Learn more here: https://nextjs.org/telemetry
# Uncomment the following line in case you want to disable telemetry during the build.
# ENV NEXT_TELEMETRY_DISABLED 1
# RUN yarn build
RUN npx prisma generate
RUN npm run build
# Production image, copy all the files and run next
FROM node:16.19 AS runner
WORKDIR /app
ENV NODE_ENV production
# Uncomment the following line in case you want to disable telemetry during runtime.
# ENV NEXT_TELEMETRY_DISABLED 1
RUN addgroup --system --gid 1001 nodejs
RUN adduser --system --uid 1001 nextjs
COPY --from=builder /app/public ./public
# Copy over the prisma schema so we can to `npx prisma db push` and ensure the
# database exists on startup.
COPY --chown=nextjs:nodejs ./website/prisma/schema.prisma ./
# Copy over a startup script that'll run `npx prisma db push` before starting
# the webserver. This ensures the webserver can actually check user accounts.
# This is a prisma variant of the postgres solution suggested in
# https://docs.docker.com/compose/startup-order/
COPY --chown=nextjs:nodejs ./website/wait-for-postgres.sh ./
# Automatically leverage output traces to reduce image size
# https://nextjs.org/docs/advanced-features/output-file-tracing
COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone ./
COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static
USER nextjs
EXPOSE 3000
ENV PORT 3000
CMD ["node", "server.js"]
+3
View File
@@ -0,0 +1,3 @@
# Shared Python code for Open Assisstant
Run `pip install -e .` to install the package in editable mode.

Some files were not shown because too many files have changed in this diff Show More