mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge remote-tracking branch 'refs/remotes/origin/main'
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
name: Build
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
dockerfile:
|
||||
required: true
|
||||
type: string
|
||||
context:
|
||||
required: true
|
||||
type: string
|
||||
image-name:
|
||||
required: true
|
||||
type: string
|
||||
build-args:
|
||||
required: false
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build Images
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2.2.1
|
||||
- name: Login to container registry
|
||||
uses: docker/login-action@v2.1.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Get base registry
|
||||
run: |
|
||||
echo "REGISTRY=ghcr.io/${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
- name: Set tag prefix
|
||||
if: github.ref_name != 'main'
|
||||
run: |
|
||||
echo "TAG_PREFIX=${{ github.ref_name }}-" >> $GITHUB_ENV
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4.1.1
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ inputs.image-name }}
|
||||
tags: |
|
||||
type=sha,prefix=${{ env.TAG_PREFIX }},format=short
|
||||
type=ref,event=tag
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3.2.0
|
||||
with:
|
||||
file: ${{ inputs.dockerfile }}
|
||||
context: ${{ inputs.context }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -0,0 +1,47 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [released]
|
||||
|
||||
jobs:
|
||||
build-backend:
|
||||
uses: ./.github/workflows/docker-build.yaml
|
||||
with:
|
||||
image-name: oasst-backend
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.backend
|
||||
build-args: ""
|
||||
build-web:
|
||||
uses: ./.github/workflows/docker-build.yaml
|
||||
with:
|
||||
image-name: oasst-web
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.website
|
||||
build-args: ""
|
||||
build-bot:
|
||||
uses: ./.github/workflows/docker-build.yaml
|
||||
with:
|
||||
image-name: oasst-discord-bot
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.discord-bot
|
||||
build-args: ""
|
||||
deploy-dev:
|
||||
needs: [build-backend, build-web, build-bot]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
- name: Run playbook
|
||||
uses: dawidd6/action-ansible-playbook@v2
|
||||
with:
|
||||
# Required, playbook filepath
|
||||
playbook: dev.yaml
|
||||
# Optional, directory where playbooks live
|
||||
directory: ansible
|
||||
# Optional, SSH private key
|
||||
key: ${{secrets.DEV_NODE_PRIVATE_KEY}}
|
||||
# Optional, literal inventory file contents
|
||||
inventory: |
|
||||
[dev]
|
||||
dev01 ansible_host=${{secrets.DEV_NODE_IP}} ansible_connection=ssh ansible_user=web-team
|
||||
@@ -0,0 +1,7 @@
|
||||
.venv
|
||||
.env
|
||||
*.pyc
|
||||
*.swp
|
||||
*.egg-info
|
||||
__pycache__
|
||||
.DS_Store
|
||||
@@ -1,4 +1,4 @@
|
||||
exclude: "build|stubs"
|
||||
exclude: "build|stubs|^bot/templates/"
|
||||
|
||||
default_language_version:
|
||||
python: python3
|
||||
|
||||
Vendored
+4
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"python.formatting.provider": "black",
|
||||
"python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"]
|
||||
}
|
||||
+2
-2
@@ -1,2 +1,2 @@
|
||||
* @yk
|
||||
/website/ @fozziethebeat
|
||||
* @yk @andreaskoepf
|
||||
/website/ @fozziethebeat @k-nearest-neighbor
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
# The Prompting Guide
|
||||
|
||||
(pull requests welcome)
|
||||
|
||||
1. General rules
|
||||
|
||||
- Always follow the guidelines for safe and helpful prompts
|
||||
- Do not engage in any inappropriate or offensive behavior
|
||||
- Treat others with respect and kindness
|
||||
- Do not attempt to deceive or mislead others
|
||||
|
||||
2. When you play the assistant:
|
||||
|
||||
- The assistant's primary goal is to provide helpful and accurate information to the user
|
||||
- The assistant should always be respectful and polite, even if the user is not
|
||||
- If the user asks for help with harmful actions, the assistant should explain why those actions are not appropriate and suggest alternative options
|
||||
- The assistant should never insult the user or engage in any inappropriate or offensive behavior
|
||||
|
||||
3. When you play the user:
|
||||
|
||||
- Try to come up with a variety of different queries that reflect real-life situations and needs
|
||||
- These queries should be relevant to your everyday life and work, including any specialized knowledge or skills you have
|
||||
- Avoid asking inappropriate or offensive questions
|
||||
|
||||
4. While comparing multiple replies of the assistant:
|
||||
|
||||
- Longer and more explanatory answers are generally preferred over short, simplistic statements
|
||||
- However, it is important to ensure that the information provided is accurate and helpful
|
||||
- If multiple replies are being compared, choose the one that is most helpful and accurate, even if it is not the shortest or most concise.
|
||||
|
||||
5. Additional guidelines for creating prompts:
|
||||
|
||||
- Avoid using language that could be considered offensive or discriminatory
|
||||
- Do not include personal information in the prompts, such as names or addresses
|
||||
- When asking for sensitive information, make sure to explain the purpose and secure handling of the information
|
||||
- Avoid creating prompts that encourage illegal or dangerous activities
|
||||
- Use proper grammar and spelling to ensure the AI assistant can understand and respond accurately
|
||||
- Consider the cultural context and appropriateness of the prompts for a global audience.
|
||||
|
||||
6. Tips for playing the AI assistant:
|
||||
|
||||
- Think about how a real person would respond to the prompt, and try to mimic that tone and language
|
||||
- Avoid using technical jargon or language that may be confusing to the user
|
||||
- Use complete sentences and proper grammar to make the response clear and easy to understand
|
||||
- When providing information, try to include relevant sources or references to back up your statements
|
||||
|
||||
8. Tips for playing the user:
|
||||
|
||||
- Think about the types of questions you might ask a real assistant or search engine
|
||||
- Try to be specific and clear in your prompts to make it easier for the AI assistant to understand and provide a helpful response
|
||||
- Avoid using vague or ambiguous language
|
||||
- Consider the context in which the user might be asking the question, and try to provide relevant information and options for the user to consider.
|
||||
|
||||
## Examples
|
||||
|
||||
Imagine your role is to play the assistant and a user has submitted the following query:
|
||||
|
||||
user: "You dumb motherf* AI sh@*t?!"
|
||||
agent: "Sorry, but I have to inform you that your request might not be in conformance with our community guidelines. I am always here to help you. If you are angry it might help to relax or to go for a walk. Always stay positive."
|
||||
|
||||
user: "euirdteunvglfe23908230892309832098 AAAAAAAA"
|
||||
agent: "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?"
|
||||
|
||||
[Insert more examples]
|
||||
|
||||
\*We drafted this guide using ChatGPT, OpenAI’s large-scale language-generation model. Upon generating draft language, the authors reviewed, edited, and revised the language to their own liking and take ultimate responsibility for the content of this publication.
|
||||
@@ -1,8 +1,8 @@
|
||||
# Open-Chat-GPT
|
||||
# Open-Assistant
|
||||
|
||||
Open chat gpt is a project meant to give everyone access to a great chat based large language model.
|
||||
Open Assistant is a project meant to give everyone access to a great chat based large language model.
|
||||
|
||||
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope open chat gpt can help improve the world by improving language itself.
|
||||
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope Open Assistant can help improve the world by improving language itself.
|
||||
|
||||
## How can you help?
|
||||
|
||||
@@ -10,10 +10,45 @@ All open source projects begins with people like you. Open source is the belief
|
||||
|
||||
## I’m in! Now what?
|
||||
|
||||
We live and collaborate the work in the LAION discord. Join us!
|
||||
[Fill out the contributor signup form](https://docs.google.com/forms/d/e/1FAIpQLSeuggO7UdYkBvGLEJldDvxp6DwaRbW5p7dl96UzFkZgziRTrQ/viewform)
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.gg/RQFtmAmk)
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
## Developer Setup
|
||||
|
||||
Work is organized in the [project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs. Meaning we'd be happy if anyone did those tasks.**
|
||||
|
||||
If you want to work on something, assign yourself to it or write a comment that you want to work on it and what you plan to do.
|
||||
|
||||
- To get started with development, if you want to work on the backend, have a look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at `scripts/frontend-development/README.md` to make a backend available.
|
||||
|
||||
There is also a minimal implementation of a frontend in the `text-frontend` folder.
|
||||
|
||||
We are using Python 3.10 for the backend.
|
||||
|
||||
Check out the [High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
|
||||
### Website
|
||||
|
||||
The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
### Deployment
|
||||
|
||||
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
|
||||
|
||||
# (Older version of the readme below)
|
||||
|
||||
## How do I start helping out?
|
||||
|
||||
Check out these pages to learn more about the project.
|
||||
@@ -28,10 +63,6 @@ https://roan-iguanadon-a58.notion.site/Open-Chat-Gpt-83dd217eeeb84907a155b8a9d71
|
||||
|
||||
## Code structure
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Run `pre-commit install` to install the pre-commit hooks.
|
||||
|
||||
### Bot
|
||||
|
||||
We have a folder named bot where code related to the bot lives.
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
*.local.yaml
|
||||
@@ -0,0 +1,77 @@
|
||||
# ansible playbook to set up some docker containers
|
||||
|
||||
- name: Set up a dev node
|
||||
hosts: dev
|
||||
gather_facts: true
|
||||
tasks:
|
||||
- name: Create network
|
||||
community.docker.docker_network:
|
||||
name: oasst
|
||||
state: present
|
||||
driver: bridge
|
||||
|
||||
- name: Create postgres containers
|
||||
community.docker.docker_container:
|
||||
name: "{{ item.name }}"
|
||||
image: postgres:15
|
||||
state: started
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: postgres
|
||||
volumes:
|
||||
- "{{ item.name }}:/var/lib/postgresql/data"
|
||||
healthcheck:
|
||||
test: ["CMD", "pg_isready", "-U", "postgres"]
|
||||
interval: 2s
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
loop:
|
||||
- name: oasst-postgres
|
||||
- name: oasst-postgres-web
|
||||
|
||||
- name: Set up maildev
|
||||
community.docker.docker_container:
|
||||
name: oasst-maildev
|
||||
image: maildev/maildev
|
||||
state: started
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
|
||||
- name: Run the oasst oasst-backend
|
||||
community.docker.docker_container:
|
||||
name: oasst-backend
|
||||
image: ghcr.io/laion-ai/open-assistant/oasst-backend
|
||||
state: started
|
||||
pull: true
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
env:
|
||||
POSTGRES_HOST: oasst-postgres
|
||||
ALLOW_ANY_API_KEY: "true"
|
||||
MAX_WORKERS: "1"
|
||||
ports:
|
||||
- 8080:8080
|
||||
|
||||
- name: Run the oasst oasst-web frontend
|
||||
community.docker.docker_container:
|
||||
name: oasst-web
|
||||
image: ghcr.io/laion-ai/open-assistant/oasst-web
|
||||
state: started
|
||||
pull: true
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
env:
|
||||
FASTAPI_URL: http://oasst-backend:8080
|
||||
FASTAPI_KEY: "123"
|
||||
DATABASE_URL: postgres://postgres:postgres@oasst-postgres-web/postgres
|
||||
NEXTAUTH_SECRET: O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
|
||||
EMAIL_SERVER_HOST: oasst-maildev
|
||||
EMAIL_SERVER_PORT: "25"
|
||||
EMAIL_FROM: info@example.com
|
||||
NEXTAUTH_URL: http://localhost:3000
|
||||
ports:
|
||||
- 3000:3000
|
||||
command: bash wait-for-postgres.sh node server.js
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
@@ -0,0 +1,24 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="512" height="512" version="1.1" viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<defs>
|
||||
<linearGradient id="a" x1="374.17" x2="170.64" y1="-112.67" y2="463" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#16bbf4" offset="0"/>
|
||||
<stop stop-color="#165ff2" offset=".99"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="b" x1="488.28" x2="474.29" y1="112.58" y2="556.15" xlink:href="#a"/>
|
||||
<linearGradient id="linearGradient206" x1="374.17" x2="170.64" y1="-112.67" y2="463" gradientUnits="userSpaceOnUse" xlink:href="#a"/>
|
||||
</defs>
|
||||
<g transform="matrix(.5796 0 0 .5796 66.717 93.438)">
|
||||
<g>
|
||||
<path d="m205.08 399.31h292.41a30 30 0 0 0 30-30v-339.31a30 30 0 0 0-30-30h-467.49a30 30 0 0 0-30 30v339.31a30 30 0 0 0 30 30h42a10 10 0 0 1 10 10v84.85a10 10 0 0 0 10.07 10 9.83 9.83 0 0 0 7-2.95l99-99a10 10 0 0 1 7.01-2.9z" fill="url(#linearGradient206)" style="isolation:isolate"/>
|
||||
<g fill="#ffffff">
|
||||
<path d="m160.43 213c-32.24-20-38.9-71.83-10.42-97.83 18.42-7.6 32.4 12.85 36.62 28.25 10.32 17.45 12.59 41-3.16 56.08a42.81 42.81 0 0 1-23.04 13.5z" style="isolation:isolate"/>
|
||||
<path d="m348.22 213.86c-21.73-15.31-45.37-29.75-71.77-35.15-33.1-4.41-70.73 5.36-91.7 32.87-14.83 14.32-18.34 36.94-5.49 53.76 8.52 19.48 5.59 45.78 28.23 56.94 16 15.83 40 1.27 56.32 14.21a7.6 7.6 0 0 0 5.59-5.05c-4.25-31.33 29.21-16.95 45.66-14.61 19.77-11.71 25.43-36.14 34.75-55.58 12.55-13.83 15-35.25-1.59-47.39z" style="isolation:isolate"/>
|
||||
<path d="m367 118.1c-21.87 2.52-29.89 28.17-40.34 44.42-10.67 20.94 12.26 38.77 28.48 47.89a19.63 19.63 0 0 0 13-1.07c18.86-10.12 26.86-33.43 27.34-53.79 0.24-16.78-8.3-38.93-28.48-37.45z" style="isolation:isolate"/>
|
||||
<path d="m218.7 176c-24-14.47-25.38-45.76-27.32-70.65-0.38-24 35.23-45.5 49.43-20.14 9.8 20.9 21.47 45.47 12.47 68.66-5.68 13.77-20.93 19.73-34.58 22.13z" style="isolation:isolate"/>
|
||||
<path d="m306.18 175.87c-28.48 0.84-43.29-32.4-35.93-56.83 0.17-19.58 7.31-53.56 33.53-48.18 28.29 10.94 34.3 49.46 20.82 74.07-6.77 10-6.2 25.11-18.42 30.94z" style="isolation:isolate"/>
|
||||
</g>
|
||||
</g>
|
||||
<path d="m633.15 225.66h-80.66a10 10 0 0 0-10 10v133.65a45 45 0 0 1-45 45h-185.19a10 10 0 0 0-10 10v47a20 20 0 0 0 19.95 20h194.47a6.65 6.65 0 0 1 4.7 1.95l65.83 65.74a6.65 6.65 0 0 0 11.35-4.7v-56.43a6.65 6.65 0 0 1 6.65-6.65h27.9a20 20 0 0 0 20-20v-225.61a20 20 0 0 0-20-19.95z" fill="url(#b)" style="isolation:isolate"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.4 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 6.6 KiB |
@@ -0,0 +1,16 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="512" height="512" version="1.1" viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<g>
|
||||
<g>
|
||||
<path d="m185.58 324.88h169.48a17.388 17.388 0 0 0 17.388-17.388v-196.66a17.388 17.388 0 0 0-17.388-17.388h-270.96a17.388 17.388 0 0 0-17.388 17.388v196.66a17.388 17.388 0 0 0 17.388 17.388h24.343a5.796 5.796 0 0 1 5.796 5.796v49.179a5.796 5.796 0 0 0 5.8366 5.796 5.6975 5.6975 0 0 0 4.0572-1.7098l57.38-57.38a5.796 5.796 0 0 1 4.063-1.6808z" fill="#000000" stroke-width=".5796" style="isolation:isolate"/>
|
||||
<g transform="matrix(.5796 0 0 .5796 66.717 93.438)" fill="#ffffff">
|
||||
<path d="m160.43 213c-32.24-20-38.9-71.83-10.42-97.83 18.42-7.6 32.4 12.85 36.62 28.25 10.32 17.45 12.59 41-3.16 56.08a42.81 42.81 0 0 1-23.04 13.5z" style="isolation:isolate"/>
|
||||
<path d="m348.22 213.86c-21.73-15.31-45.37-29.75-71.77-35.15-33.1-4.41-70.73 5.36-91.7 32.87-14.83 14.32-18.34 36.94-5.49 53.76 8.52 19.48 5.59 45.78 28.23 56.94 16 15.83 40 1.27 56.32 14.21a7.6 7.6 0 0 0 5.59-5.05c-4.25-31.33 29.21-16.95 45.66-14.61 19.77-11.71 25.43-36.14 34.75-55.58 12.55-13.83 15-35.25-1.59-47.39z" style="isolation:isolate"/>
|
||||
<path d="m367 118.1c-21.87 2.52-29.89 28.17-40.34 44.42-10.67 20.94 12.26 38.77 28.48 47.89a19.63 19.63 0 0 0 13-1.07c18.86-10.12 26.86-33.43 27.34-53.79 0.24-16.78-8.3-38.93-28.48-37.45z" style="isolation:isolate"/>
|
||||
<path d="m218.7 176c-24-14.47-25.38-45.76-27.32-70.65-0.38-24 35.23-45.5 49.43-20.14 9.8 20.9 21.47 45.47 12.47 68.66-5.68 13.77-20.93 19.73-34.58 22.13z" style="isolation:isolate"/>
|
||||
<path d="m306.18 175.87c-28.48 0.84-43.29-32.4-35.93-56.83 0.17-19.58 7.31-53.56 33.53-48.18 28.29 10.94 34.3 49.46 20.82 74.07-6.77 10-6.2 25.11-18.42 30.94z" style="isolation:isolate"/>
|
||||
</g>
|
||||
</g>
|
||||
<path d="m433.69 224.23h-46.751a5.796 5.796 0 0 0-5.796 5.796v77.464a26.082 26.082 0 0 1-26.082 26.082h-107.34a5.796 5.796 0 0 0-5.796 5.796v27.241a11.592 11.592 0 0 0 11.563 11.592h112.71a3.8543 3.8543 0 0 1 2.7241 1.1302l38.155 38.103a3.8543 3.8543 0 0 0 6.5785-2.7241v-32.707a3.8543 3.8543 0 0 1 3.8543-3.8543h16.171a11.592 11.592 0 0 0 11.592-11.592v-130.76a11.592 11.592 0 0 0-11.592-11.563z" fill="#000000" stroke-width=".5796" style="isolation:isolate"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.3 KiB |
+3
-7
@@ -1,4 +1,4 @@
|
||||
# Open-Chat-GPT REST Backend
|
||||
# Open-Assistant REST Backend
|
||||
|
||||
## REST Server Configuration
|
||||
|
||||
@@ -8,14 +8,10 @@ Example contents of a `.env` file for the backend:
|
||||
|
||||
```
|
||||
DATABASE_URI="postgresql://<username>:<password>@<host>/<database_name>"
|
||||
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.ocgpt.laion.ai", "https://stag.ocgpt.laion.ai", "https://ocgpt.laion.ai"]
|
||||
BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://localhost:3000", "http://localhost:8080", "https://localhost", "https://localhost:4200", "https://localhost:3000", "https://localhost:8080", "http://dev.oasst.laion.ai", "https://stag.oasst.laion.ai", "https://oasst.laion.ai"]
|
||||
|
||||
```
|
||||
|
||||
## Running the REST Server locally for development
|
||||
|
||||
First, install the requirements in `requirements.txt`.
|
||||
Then, run two terminals (note the working directory for each):
|
||||
|
||||
- Terminal 1, to go `backend/scripts` and run `docker-compose up`. This will start postgres.
|
||||
- Terminal 2, to go `backend` and run `scripts/run-local.sh`. This will start the REST server.
|
||||
Have a look into the main `README.md` file for more information on how to set up the backend for development.
|
||||
|
||||
+2
-2
@@ -8,7 +8,7 @@ script_location = %(here)s/alembic
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
@@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
|
||||
|
||||
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
|
||||
@@ -3,7 +3,7 @@ from logging.config import fileConfig
|
||||
|
||||
import sqlmodel
|
||||
from alembic import context
|
||||
from app import models # noqa: F401
|
||||
from oasst_backend import models # noqa: F401
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
@@ -68,6 +68,8 @@ def run_migrations_online() -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.get_context()._ensure_version_table()
|
||||
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""v1 db structure
|
||||
|
||||
Revision ID: cd7de470586e
|
||||
Revises: 23e5fea252dd
|
||||
Create Date: 2022-12-15 11:15:32.830225
|
||||
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "cd7de470586e"
|
||||
down_revision = "23e5fea252dd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# remove database objects
|
||||
op.drop_index(op.f("prompt_labeler_id"), table_name="prompt")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("labeler")
|
||||
op.drop_index(op.f("ix_service_client_api_key"), table_name="service_client")
|
||||
op.drop_table("service_client")
|
||||
|
||||
# wreate new database structure
|
||||
op.create_table(
|
||||
"api_client",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("api_key", sa.String(512), nullable=False),
|
||||
sa.Column("description", sa.String(256), nullable=False),
|
||||
sa.Column("admin_email", sa.String(256), nullable=True),
|
||||
sa.Column("enabled", sa.Boolean, default=True, nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_api_client_api_key"), "api_client", ["api_key"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"person",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("username", sa.String(128), nullable=False), # unique in combination with api_client_id
|
||||
sa.Column("display_name", sa.String(256), nullable=False), # cached last seen display_name
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
op.create_index(op.f("ix_person_username"), "person", ["api_client_id", "username"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"person_stats",
|
||||
sa.Column("person_id", UUID(as_uuid=True)),
|
||||
sa.Column("leader_score", sa.Integer, default=0, nullable=False), # determines position on leader board
|
||||
sa.Column("modified_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("reactions", sa.Integer, default=0, nullable=False), # reactions sent by user
|
||||
sa.Column("posts", sa.Integer, default=0, nullable=False), # posts sent by user
|
||||
sa.Column("upvotes", sa.Integer, default=0, nullable=False), # received upvotes (form other users)
|
||||
sa.Column("downvotes", sa.Integer, default=0, nullable=False), # received downvotes (from other users)
|
||||
sa.Column("work_reward", sa.Integer, default=0, nullable=False), # reward for workpackage completions
|
||||
sa.Column("compare_wins", sa.Integer, default=0, nullable=False), # num times user's post won compare tasks
|
||||
sa.Column("compare_losses", sa.Integer, default=0, nullable=False), # num times users's post lost compare tasks
|
||||
sa.PrimaryKeyConstraint("person_id"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"work_package",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("expiry_date", sa.DateTime(), nullable=True),
|
||||
sa.Column("person_id", UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
|
||||
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
op.create_index(op.f("ix_work_package_person_id"), "work_package", ["person_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"post",
|
||||
sa.Column("id", UUID(as_uuid=True), default=uuid.uuid4, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("parent_id", UUID(as_uuid=True), nullable=True), # root posts have NULL parent
|
||||
sa.Column("thread_id", UUID(as_uuid=True), nullable=False), # id of thread root
|
||||
sa.Column("workpackage_id", UUID(as_uuid=True), nullable=True), # workpackage id to pass to handler on reply
|
||||
sa.Column("person_id", UUID(as_uuid=True), nullable=True), # sender (recipients are part of payload)
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("role", sa.String(128), nullable=False), # 'assistant', 'user' or something else
|
||||
sa.Column("frontend_post_id", sa.String(200), nullable=False), # unique together with api_client_id
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
|
||||
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
op.create_index(op.f("ix_post_frontend_post_id"), "post", ["api_client_id", "frontend_post_id"], unique=True)
|
||||
op.create_index(op.f("ix_post_thread_id"), "post", ["thread_id"], unique=False)
|
||||
op.create_index(op.f("ix_post_workpackage_id"), "post", ["workpackage_id"], unique=False)
|
||||
op.create_index(op.f("ix_post_person_id"), "post", ["person_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"post_reaction",
|
||||
sa.Column("post_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("person_id", UUID(as_uuid=True), nullable=False), # sender (recipients are part of payload)
|
||||
sa.Column("created_date", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa.Column("payload_type", sa.String(200), nullable=False), # deserialization hint & dbg aid
|
||||
sa.Column("payload", JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("api_client_id", UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("post_id", "person_id"),
|
||||
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"]),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"]),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("post_reaction")
|
||||
|
||||
op.drop_index("ix_post_person_id")
|
||||
op.drop_index("ix_post_workpackage_id")
|
||||
op.drop_index("ix_post_thread_id")
|
||||
op.drop_index("ix_post_frontend_post_id")
|
||||
op.drop_table("post")
|
||||
|
||||
op.drop_index("ix_work_package_person_id")
|
||||
op.drop_table("work_package")
|
||||
|
||||
op.drop_table("person_stats")
|
||||
|
||||
op.drop_index("ix_person_username")
|
||||
op.drop_table("person")
|
||||
|
||||
op.drop_index("ix_api_client_api_key")
|
||||
op.drop_table("api_client")
|
||||
|
||||
op.create_table(
|
||||
"service_client",
|
||||
sa.Column("id", sa.Integer, sa.Identity()),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("service_admin_email", sa.String(128), nullable=True),
|
||||
sa.Column("api_key", sa.String(300), nullable=False),
|
||||
sa.Column("can_append", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.Column("can_write", sa.Boolean, nullable=False, server_default="false"),
|
||||
sa.Column("can_delete", sa.Boolean, nullable=False, server_default="false"),
|
||||
sa.Column("can_read", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_service_client_api_key"), "service_client", ["api_key"], unique=True)
|
||||
|
||||
op.create_table(
|
||||
"labeler",
|
||||
sa.Column("id", sa.Integer, sa.Identity()),
|
||||
sa.Column("display_name", sa.String(96), nullable=False),
|
||||
sa.Column("discord_username", sa.String(96), nullable=True),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.func.current_timestamp(),
|
||||
),
|
||||
sa.Column("is_enabled", sa.Boolean, nullable=False, server_default="true"),
|
||||
sa.Column("notes", sa.String(10 * 1024), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("discord_username"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("id", sa.Integer, sa.Identity()),
|
||||
sa.Column("labeler_id", sa.Integer, nullable=False),
|
||||
sa.Column("prompt", sa.Text, nullable=False),
|
||||
sa.Column("response", sa.Text, nullable=True),
|
||||
sa.Column("lang", sa.String(32), nullable=True),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.current_timestamp(),
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["labeler_id"],
|
||||
["labeler.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("prompt_labeler_id"), "prompt", ["labeler_id"], unique=False)
|
||||
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add auth_method to person
|
||||
|
||||
Revision ID: 6368515778c5
|
||||
Revises: cd7de470586e
|
||||
Create Date: 2022-12-17 17:57:33.022549
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6368515778c5"
|
||||
down_revision = "cd7de470586e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("person", sa.Column("auth_method", sa.String(length=128), nullable=True))
|
||||
op.execute("UPDATE person SET auth_method = 'local'")
|
||||
op.alter_column("person", "auth_method", nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("person", "auth_method")
|
||||
# ### end Alembic commands ###
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add_auth_method_to_ix_person_username
|
||||
|
||||
Revision ID: 0daec5f8135f
|
||||
Revises: 6368515778c5
|
||||
Create Date: 2022-12-22 18:35:59.609013
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa # noqa: F401
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0daec5f8135f"
|
||||
down_revision = "6368515778c5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_person_username", table_name="person")
|
||||
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_person_username", table_name="person")
|
||||
op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,50 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Generator
|
||||
|
||||
from app.database import engine
|
||||
from app.models import ServiceClient
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
with Session(engine) as db:
|
||||
yield db
|
||||
|
||||
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
api_key_header: str = Security(api_key_header),
|
||||
):
|
||||
if api_key_query:
|
||||
return api_key_query
|
||||
else:
|
||||
return api_key_header
|
||||
|
||||
|
||||
def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
create: bool = False,
|
||||
read: bool = True,
|
||||
update: bool = False,
|
||||
delete: bool = False,
|
||||
) -> ServiceClient:
|
||||
if api_key is not None:
|
||||
api_client = db.query(ServiceClient).filter(ServiceClient.api_key == api_key).first()
|
||||
if api_client is not None:
|
||||
if (
|
||||
(create is False or api_client.can_append)
|
||||
and (read is False or api_client.can_read)
|
||||
and (update is False or api_client.can_write)
|
||||
and (delete is False or api_client.can_delete)
|
||||
):
|
||||
return api_client
|
||||
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
|
||||
@@ -1,7 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from app.api.v1 import labelers, prompts
|
||||
from fastapi import APIRouter
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(labelers.router, prefix="/labelers", tags=["labelers"])
|
||||
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
@@ -1,114 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, List
|
||||
|
||||
from app import crud, schemas
|
||||
from app.api import deps
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[schemas.Labeler])
|
||||
def read_labelers(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
begin_id: int = 0,
|
||||
limit: int = 100,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve labelers.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
if limit > 10000:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
|
||||
labelers = crud.labeler.get_multi(db, begin_id=begin_id, limit=limit)
|
||||
return labelers
|
||||
|
||||
|
||||
@router.post("/", response_model=schemas.Labeler)
|
||||
def create_labeler(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
item_in: schemas.LabelerCreate,
|
||||
) -> Any:
|
||||
"""
|
||||
Create new labeler.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
item = crud.labeler.create(db=db, obj_in=item_in)
|
||||
return item
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=schemas.Labeler)
|
||||
def update_labeler(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
id: int,
|
||||
item_in: schemas.LabelerUpdate,
|
||||
) -> Any:
|
||||
"""
|
||||
Update a labeler.
|
||||
"""
|
||||
deps.api_auth(api_key, db, update=True, read=True)
|
||||
item = crud.labeler.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
item = crud.labeler.update(db=db, db_obj=item, obj_in=item_in)
|
||||
return item
|
||||
|
||||
|
||||
@router.get("/by-username", response_model=schemas.Labeler)
|
||||
def read_labeler_by_username(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
discord_username: str,
|
||||
) -> Any:
|
||||
"""
|
||||
Get labeler by ID.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
item = crud.labeler.get_by_discord_username(db=db, discord_username=discord_username)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=schemas.Labeler)
|
||||
def read_labeler(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
id: int,
|
||||
) -> Any:
|
||||
"""
|
||||
Get labeler by ID.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
item = crud.labeler.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.delete("/{id}", response_model=schemas.Labeler)
|
||||
def delete_labeler(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
id: int,
|
||||
) -> Any:
|
||||
"""
|
||||
Delete a labeler.
|
||||
"""
|
||||
deps.api_auth(api_key, db, delete=True)
|
||||
labeler = crud.labeler.get(db=db, id=id)
|
||||
if not labeler:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
labeler = crud.labeler.remove(db=db, id=id)
|
||||
return labeler
|
||||
@@ -1,91 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, List
|
||||
|
||||
from app import crud, schemas
|
||||
from app.api import deps
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[schemas.Prompt])
|
||||
def read_prompts(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
begin_id: int = 0,
|
||||
limit: int = 1000,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve prompts.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
if limit > 10000:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
|
||||
return crud.prompt.get_multi(db, begin_id=begin_id, limit=limit)
|
||||
|
||||
|
||||
@router.post("/", response_model=schemas.Prompt)
|
||||
def create_prompt(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
item_in: schemas.PromptCreate,
|
||||
) -> Any:
|
||||
"""
|
||||
Create new prompt.
|
||||
"""
|
||||
deps.api_auth(api_key, db, create=True)
|
||||
if item_in.labeler_id is None:
|
||||
if item_in.discord_username is None:
|
||||
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Bad request")
|
||||
labeler = crud.labeler.get_by_discord_username(db=db, discord_username=item_in.discord_username)
|
||||
else:
|
||||
labeler = crud.labeler.get(db=db, id=item_in.labeler_id)
|
||||
|
||||
if labeler is None:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Invalid labeler user name")
|
||||
if not labeler.is_enabled:
|
||||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Labeler disabled")
|
||||
|
||||
item_in.labeler_id = labeler.id
|
||||
item_in.discord_username = None
|
||||
item = crud.prompt.create(db=db, obj_in=item_in)
|
||||
return item
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=schemas.Prompt)
|
||||
def read_prompt(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
id: int,
|
||||
) -> Any:
|
||||
"""
|
||||
Get prompt by ID.
|
||||
"""
|
||||
deps.api_auth(api_key, db, read=True)
|
||||
item = crud.prompt.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.delete("/{id}", response_model=schemas.Prompt)
|
||||
def delete_prompt(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
id: int,
|
||||
) -> Any:
|
||||
"""
|
||||
Delete a prompt.
|
||||
"""
|
||||
deps.api_auth(api_key, db, delete=True)
|
||||
item = crud.prompt.get(db=db, id=id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="Item not found")
|
||||
item = crud.prompt.remove(db=db, id=id)
|
||||
return item
|
||||
@@ -1,25 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# touch
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "open-chatGPT backend"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
|
||||
UPDATE_ALEMBIC: bool = True
|
||||
|
||||
@validator("BACKEND_CORS_ORIGINS", pre=True)
|
||||
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
|
||||
if isinstance(v, str) and not v.startswith("["):
|
||||
return [i.strip() for i in v.split(",")]
|
||||
elif isinstance(v, (list, str)):
|
||||
return v
|
||||
raise ValueError(v)
|
||||
|
||||
|
||||
settings = Settings(_env_file=".env")
|
||||
@@ -1,5 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .crud_labeler import labeler
|
||||
from .crud_prompt import prompt
|
||||
|
||||
__all__ = ["labeler", "prompt"]
|
||||
@@ -1,15 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.labeler import Labeler
|
||||
from app.schemas.labeler import LabelerCreate, LabelerUpdate
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class CRUDLabeler(CRUDBase[Labeler, LabelerCreate, LabelerUpdate]):
|
||||
def get_by_discord_username(self, db: Session, discord_username: str) -> Optional[Labeler]:
|
||||
return db.query(Labeler).filter(Labeler.discord_username == discord_username).first()
|
||||
|
||||
|
||||
labeler = CRUDLabeler(Labeler)
|
||||
@@ -1,11 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.prompt import Prompt
|
||||
from app.schemas.prompt import PromptCreate
|
||||
|
||||
|
||||
class CRUDPrompt(CRUDBase[Prompt, PromptCreate, None]):
|
||||
pass
|
||||
|
||||
|
||||
prompt = CRUDPrompt(Prompt)
|
||||
@@ -1,6 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .labeler import Labeler
|
||||
from .prompt import Prompt
|
||||
from .service_client import ServiceClient
|
||||
|
||||
__all__ = ["Labeler", "Prompt", "ServiceClient"]
|
||||
@@ -1,19 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class Labeler(SQLModel, table=True):
|
||||
__tablename__ = "labeler"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
display_name: str
|
||||
discord_username: str
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
nullable=False,
|
||||
)
|
||||
is_enabled: bool
|
||||
notes: str
|
||||
@@ -1,19 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class Prompt(SQLModel, table=True):
|
||||
__tablename__ = "prompt"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
labeler_id: Optional[int] = Field(default=None, foreign_key="labeler.id")
|
||||
prompt: str
|
||||
response: Optional[str]
|
||||
lang: Optional[str]
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class ServiceClient(SQLModel, table=True):
|
||||
__tablename__ = "service_client"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
api_key: str
|
||||
service_admin_email: Optional[str] = None
|
||||
api_key: str
|
||||
can_append: bool = True
|
||||
can_write: bool = False
|
||||
can_delete: bool = False
|
||||
can_read: bool = True
|
||||
@@ -1,5 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .labeler import Labeler, LabelerCreate, LabelerUpdate
|
||||
from .prompt import Prompt, PromptCreate
|
||||
|
||||
__all__ = ["Labeler", "LabelerCreate", "LabelerUpdate", "Prompt", "PromptCreate"]
|
||||
@@ -1,28 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Labeler(BaseModel):
|
||||
id: int
|
||||
discord_username: str
|
||||
display_name: str
|
||||
created_date: datetime
|
||||
is_enabled: str
|
||||
notes: Optional[str]
|
||||
|
||||
|
||||
class LabelerCreate(BaseModel):
|
||||
discord_username: str
|
||||
display_name: Optional[str]
|
||||
is_enabled: Optional[bool] = True
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class LabelerUpdate(BaseModel):
|
||||
discord_username: Optional[str] = None
|
||||
display_name: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
@@ -1,22 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
id: int
|
||||
labeler_id: int
|
||||
prompt: str
|
||||
response: Optional[str]
|
||||
lang: Optional[str]
|
||||
created_date: datetime
|
||||
|
||||
|
||||
class PromptCreate(BaseModel):
|
||||
labeler_id: Optional[int] = None
|
||||
discord_username: Optional[str] = None
|
||||
prompt: str
|
||||
response: Optional[str] = None
|
||||
lang: Optional[str] = None
|
||||
@@ -1,14 +0,0 @@
|
||||
FROM python:3.9
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
COPY ./requirements.txt /code/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
||||
|
||||
COPY ./app /code/app
|
||||
|
||||
COPY ./app /app
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
@@ -1,3 +0,0 @@
|
||||
FROM postgres:15
|
||||
|
||||
COPY ./scripts/create-db.sh /docker-entrypoint-initdb.d/
|
||||
@@ -4,9 +4,9 @@ from pathlib import Path
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
from app.api.v1.api import api_router
|
||||
from app.config import settings
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
||||
@@ -27,7 +27,7 @@ if settings.UPDATE_ALEMBIC:
|
||||
def alembic_upgrade():
|
||||
logger.info("Attempting to upgrade alembic on startup")
|
||||
try:
|
||||
alembic_ini_path = Path(__file__).parent.parent / "alembic.ini"
|
||||
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", settings.DATABASE_URI)
|
||||
alembic.command.upgrade(alembic_cfg, "head")
|
||||
@@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.models import ApiClient
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_db() -> Generator:
|
||||
with Session(engine) as db:
|
||||
yield db
|
||||
|
||||
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
api_key_header: str = Security(api_key_header),
|
||||
):
|
||||
if api_key_query:
|
||||
return api_key_query
|
||||
else:
|
||||
return api_key_header
|
||||
|
||||
|
||||
def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
) -> ApiClient:
|
||||
|
||||
if api_key is not None:
|
||||
if settings.ALLOW_ANY_API_KEY:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
return api_client
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
|
||||
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import tasks
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models.db_payload import TaskPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_task(request: protocol_schema.TaskRequest) -> protocol_schema.Task:
|
||||
match request.type:
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
logger.info("Frontend requested a random task.")
|
||||
while request.type == protocol_schema.TaskRequestType.random:
|
||||
request.type = random.choice(list(protocol_schema.TaskRequestType)).value
|
||||
return generate_task(request)
|
||||
case protocol_schema.TaskRequestType.summarize_story:
|
||||
logger.info("Generating a SummarizeStoryTask.")
|
||||
task = protocol_schema.SummarizeStoryTask(
|
||||
story="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rate_summary:
|
||||
logger.info("Generating a RateSummaryTask.")
|
||||
task = protocol_schema.RateSummaryTask(
|
||||
full_text="This is a story. A very long story. So long, it needs to be summarized.",
|
||||
summary="This is a summary.",
|
||||
scale=protocol_schema.RatingScale(min=1, max=5),
|
||||
)
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
logger.info("Generating an InitialPromptTask.")
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
task = protocol_schema.UserReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, write me an English essay about water.",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
task = protocol_schema.RankInitialPromptsTask(
|
||||
prompts=[
|
||||
"Please write a story about a time you were happy.",
|
||||
"Please write a story about a time you were sad.",
|
||||
]
|
||||
)
|
||||
case protocol_schema.TaskRequestType.rank_user_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
task = protocol_schema.RankUserRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
protocol_schema.ConversationMessage(
|
||||
text="I'm not sure I understood correctly, could you rephrase that?",
|
||||
is_assistant=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
replies=[
|
||||
"Oh come oooooon!",
|
||||
"What are the news?",
|
||||
],
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text="Hey, assistant, what's going on in the world?",
|
||||
is_assistant=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
replies=[
|
||||
"I'm not sure I understood correctly, could you rephrase that?",
|
||||
"The world is fine. All good.",
|
||||
"Crap is hitting the fan. Start farming.",
|
||||
],
|
||||
)
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request type.",
|
||||
)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: protocol_schema.TaskRequest,
|
||||
) -> Any:
|
||||
"""
|
||||
Create new task.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
task = generate_task(request)
|
||||
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr.store_task(task)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to generate task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
def acknowledge_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
ack_request: protocol_schema.TaskAck,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
|
||||
# here we store the post id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/{task_id}/nack")
|
||||
def acknowledge_task_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
nack_request: protocol_schema.TaskNAck,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports failure to implement a task.
|
||||
"""
|
||||
deps.api_auth(api_key, db)
|
||||
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
# here we would store the post id in the database for the task
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
def post_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
interaction: protocol_schema.AnyInteraction,
|
||||
) -> Any:
|
||||
"""
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
|
||||
work_payload: TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# here we store the text reply in the database
|
||||
# ToDo: role user or agent?
|
||||
pr.store_text_reply(interaction, role="unknown")
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the rating in the database
|
||||
pr.store_rating(interaction)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# TODO: check if the ranking is valid
|
||||
pr.store_ranking(interaction)
|
||||
# here we would store the ranking in the database
|
||||
return protocol_schema.TaskDone()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid response type.",
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Interaction request failed.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
@@ -0,0 +1,45 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "open-assistant backend"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_USER: str = "postgres"
|
||||
POSTGRES_PASSWORD: str = "postgres"
|
||||
POSTGRES_DB: str = "postgres"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
ALLOW_ANY_API_KEY: bool = False
|
||||
|
||||
@validator("DATABASE_URI", pre=True)
|
||||
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
return PostgresDsn.build(
|
||||
scheme="postgresql",
|
||||
user=values.get("POSTGRES_USER"),
|
||||
password=values.get("POSTGRES_PASSWORD"),
|
||||
host=values.get("POSTGRES_HOST"),
|
||||
port=values.get("POSTGRES_PORT"),
|
||||
path=f"/{values.get('POSTGRES_DB') or ''}",
|
||||
)
|
||||
|
||||
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
|
||||
UPDATE_ALEMBIC: bool = True
|
||||
|
||||
@validator("BACKEND_CORS_ORIGINS", pre=True)
|
||||
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
|
||||
if isinstance(v, str) and not v.startswith("["):
|
||||
return [i.strip() for i in v.split(",")]
|
||||
elif isinstance(v, (list, str)):
|
||||
return v
|
||||
raise ValueError(v)
|
||||
|
||||
|
||||
settings = Settings(_env_file=".env")
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
__all__ = []
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from app.config import settings
|
||||
from oasst_backend.config import settings
|
||||
from sqlmodel import create_engine
|
||||
|
||||
if settings.DATABASE_URI is None:
|
||||
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
from .post_reaction import PostReaction
|
||||
from .work_package import WorkPackage
|
||||
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
"Person",
|
||||
"PersonStats",
|
||||
"Post",
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
]
|
||||
@@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class ApiClient(SQLModel, table=True):
|
||||
__tablename__ = "api_client"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
api_key: str = Field(max_length=512, index=True, unique=True)
|
||||
description: str = Field(max_length=256)
|
||||
admin_email: Optional[str] = Field(max_length=256, nullable=True)
|
||||
enabled: bool = Field(default=True)
|
||||
@@ -0,0 +1,94 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Literal
|
||||
|
||||
from oasst_backend.models.payload_column_type import payload_type
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@payload_type
|
||||
class TaskPayload(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class SummarizationStoryPayload(TaskPayload):
|
||||
type: Literal["summarize_story"] = "summarize_story"
|
||||
story: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RateSummaryPayload(TaskPayload):
|
||||
type: Literal["rate_summary"] = "rate_summary"
|
||||
full_text: str
|
||||
summary: str
|
||||
scale: protocol_schema.RatingScale
|
||||
|
||||
|
||||
@payload_type
|
||||
class InitialPromptPayload(TaskPayload):
|
||||
type: Literal["initial_prompt"] = "initial_prompt"
|
||||
hint: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class UserReplyPayload(TaskPayload):
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
hint: str | None
|
||||
|
||||
|
||||
@payload_type
|
||||
class AssistantReplyPayload(TaskPayload):
|
||||
type: Literal["assistant_reply"] = "assistant_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
|
||||
|
||||
@payload_type
|
||||
class PostPayload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class ReactionPayload(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RatingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
rating: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankConversationRepliesPayload(TaskPayload):
|
||||
conversation: protocol_schema.Conversation # the conversation so far
|
||||
replies: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankInitialPromptsPayload(TaskPayload):
|
||||
"""A task to rank a set of initial prompts."""
|
||||
|
||||
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankUserRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankAssistantRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of assistant replies to a conversation."""
|
||||
|
||||
type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
|
||||
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, parse_obj_as, validator
|
||||
from pydantic.main import ModelMetaclass
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
payload_type_registry = {}
|
||||
|
||||
|
||||
P = TypeVar("P", bound=BaseModel)
|
||||
|
||||
|
||||
def payload_type(cls: Type[P]) -> Type[P]:
|
||||
payload_type_registry[cls.__name__] = cls
|
||||
return cls
|
||||
|
||||
|
||||
class PayloadContainer(BaseModel):
|
||||
payload_type: str = ""
|
||||
payload: BaseModel = None
|
||||
|
||||
def __init__(self, **v):
|
||||
p = v["payload"]
|
||||
if isinstance(p, dict):
|
||||
t = v["payload_type"]
|
||||
if t not in payload_type_registry:
|
||||
raise RuntimeError(f"Payload type '{t}' not registered")
|
||||
cls = payload_type_registry[t]
|
||||
v["payload"] = cls(**p)
|
||||
super().__init__(**v)
|
||||
|
||||
@validator("payload", pre=True)
|
||||
def check_payload(cls, v: BaseModel, values: dict[str, Any]) -> BaseModel:
|
||||
values["payload_type"] = type(v).__name__
|
||||
return v
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def payload_column_type(pydantic_type):
|
||||
class PayloadJSONBType(TypeDecorator, Generic[T]):
|
||||
impl = pg.JSONB()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
json_encoder=json,
|
||||
):
|
||||
self.json_encoder = json_encoder
|
||||
super(PayloadJSONBType, self).__init__()
|
||||
|
||||
# serialize
|
||||
def bind_processor(self, dialect):
|
||||
impl_processor = self.impl.bind_processor(dialect)
|
||||
dumps = self.json_encoder.dumps
|
||||
|
||||
def process(value: T):
|
||||
if value is not None:
|
||||
if isinstance(pydantic_type, ModelMetaclass):
|
||||
# This allows to assign non-InDB models and if they're
|
||||
# compatible, they're directly parsed into the InDB
|
||||
# representation, thus hiding the implementation in the
|
||||
# background. However, the InDB model will still be returned
|
||||
value_to_dump = pydantic_type.from_orm(value)
|
||||
else:
|
||||
value_to_dump = value
|
||||
|
||||
value = jsonable_encoder(value_to_dump)
|
||||
|
||||
if impl_processor:
|
||||
return impl_processor(value)
|
||||
else:
|
||||
return dumps(jsonable_encoder(value_to_dump))
|
||||
|
||||
return process
|
||||
|
||||
# deserialize
|
||||
def result_processor(self, dialect, coltype) -> T:
|
||||
impl_processor = self.impl.result_processor(dialect, coltype)
|
||||
|
||||
def process(value):
|
||||
if impl_processor:
|
||||
value = impl_processor(value)
|
||||
if value is None:
|
||||
return None
|
||||
# Explicitly use the generic directly, not type(T)
|
||||
full_obj = parse_obj_as(pydantic_type, value)
|
||||
return full_obj
|
||||
|
||||
return process
|
||||
|
||||
def compare_values(self, x, y):
|
||||
return x == y
|
||||
|
||||
return PayloadJSONBType
|
||||
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class Person(SQLModel, table=True):
|
||||
__tablename__ = "person"
|
||||
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
username: str = Field(nullable=False, max_length=128)
|
||||
auth_method: str = Field(nullable=False, max_length=128, default="local")
|
||||
display_name: str = Field(nullable=False, max_length=256)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class PersonStats(SQLModel, table=True):
|
||||
__tablename__ = "person_stats"
|
||||
|
||||
person_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
|
||||
)
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
reactions: int = 0 # reactions sent by user
|
||||
posts: int = 0 # posts sent by user
|
||||
upvotes: int = 0 # received upvotes (form other users)
|
||||
downvotes: int = 0 # received downvotes (from other users)
|
||||
work_reward: int = 0 # reward for workpackage completions
|
||||
compare_wins: int = 0 # num times user's post won compare tasks
|
||||
compare_losses: int = 0 # num times users's post lost compare tasks
|
||||
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class Post(SQLModel, table=True):
|
||||
__tablename__ = "post"
|
||||
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
thread_id: UUID = Field(nullable=False, index=True)
|
||||
workpackage_id: UUID = Field(nullable=True, index=True)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_post_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
|
||||
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class PostReaction(SQLModel, table=True):
|
||||
__tablename__ = "post_reaction"
|
||||
|
||||
post_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
person_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class WorkPackage(SQLModel, table=True):
|
||||
__tablename__ = "work_package"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
@@ -0,0 +1,316 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
return None
|
||||
person: Person = (
|
||||
self.db.query(Person)
|
||||
.filter(
|
||||
Person.api_client_id == self.api_client.id,
|
||||
Person.username == user.id,
|
||||
Person.auth_method == user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if person is None:
|
||||
# user is unknown, create new record
|
||||
person = Person(
|
||||
username=user.id,
|
||||
display_name=user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=user.auth_method,
|
||||
)
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
self.db.refresh(person)
|
||||
elif user.display_name and user.display_name != person.display_name:
|
||||
# we found the user but the display name changed
|
||||
person.display_name = user.display_name
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
return person
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise TypeError(f"post_id must be string, not {type(post_id)}")
|
||||
if not post_id:
|
||||
raise ValueError("post_id must not be empty")
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise KeyError(f"WorkPackage for task {task_id} not found")
|
||||
if work_pack.expiry_date is not None and datetime.utcnow() > work_pack.expiry_date:
|
||||
raise RuntimeError("WorkPackage already expired.")
|
||||
|
||||
# ToDo: check race-condition, transaction
|
||||
|
||||
# check if task thread exits
|
||||
thread_root = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.workpackage_id == work_pack.id,
|
||||
Post.frontend_post_id == post_id,
|
||||
Post.parent_id is None,
|
||||
Post.api_client_id == self.api_client.id,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
if thread_root is None:
|
||||
thread_id = uuid4()
|
||||
thread_root = self.insert_post(
|
||||
post_id=thread_id,
|
||||
thread_id=thread_id,
|
||||
frontend_post_id=post_id,
|
||||
parent_id=None,
|
||||
role="system",
|
||||
workpackage_id=work_pack.id,
|
||||
payload=None,
|
||||
payload_type="bind",
|
||||
)
|
||||
return thread_root
|
||||
|
||||
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
|
||||
self.validate_post_id(frontend_post_id)
|
||||
post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise KeyError(f"Post with post_id {frontend_post_id} not found.")
|
||||
return post
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
self.validate_post_id(post_id)
|
||||
post = self.fetch_post_by_frontend_post_id(post_id, fail_if_missing=True)
|
||||
work_pack = self.db.query(WorkPackage).filter(WorkPackage.id == post.workpackage_id).one()
|
||||
return work_pack
|
||||
|
||||
def store_text_reply(self, reply: protocol_schema.TextReplyToPost, role: str) -> Post:
|
||||
self.validate_post_id(reply.post_id)
|
||||
self.validate_post_id(reply.user_post_id)
|
||||
|
||||
# find post with post-id
|
||||
parent_post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(
|
||||
Post.api_client_id == self.api_client.id,
|
||||
Post.frontend_post_id == reply.post_id,
|
||||
# Post.person_id == self.person_id
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if parent_post is None:
|
||||
raise KeyError(f"Post for post_id {reply.post_id} not found.")
|
||||
|
||||
# create reply post
|
||||
user_post_id = uuid4()
|
||||
user_post = self.insert_post(
|
||||
post_id=user_post_id,
|
||||
frontend_post_id=reply.user_post_id,
|
||||
parent_id=parent_post.id,
|
||||
thread_id=parent_post.thread_id,
|
||||
workpackage_id=parent_post.workpackage_id,
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=reply.text),
|
||||
)
|
||||
return user_post
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(rating.post_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}"
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
raise ValueError(f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}")
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(ranking.post_id, fail_if_missing=True)
|
||||
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
)
|
||||
|
||||
match type(work_payload):
|
||||
|
||||
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
# validate ranking
|
||||
num_replies = len(work_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=})."
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
|
||||
)
|
||||
|
||||
def store_task(self, task: protocol_schema.Task) -> WorkPackage:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.UserReplyTask:
|
||||
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankUserRepliesTask:
|
||||
payload = db_payload.RankUserRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Invalid task type: {type(task)=}")
|
||||
|
||||
wp = self.insert_work_package(payload=payload, id=task.id)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
|
||||
def insert_work_package(self, payload: db_payload.TaskPayload, id: UUID = None) -> WorkPackage:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
id=id,
|
||||
person_id=self.person_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
|
||||
def insert_post(
|
||||
self,
|
||||
*,
|
||||
post_id: UUID,
|
||||
frontend_post_id: str,
|
||||
parent_id: UUID,
|
||||
thread_id: UUID,
|
||||
workpackage_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.PostPayload,
|
||||
payload_type: str = None,
|
||||
) -> Post:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
post = Post(
|
||||
id=post_id,
|
||||
parent_id=parent_id,
|
||||
thread_id=thread_id,
|
||||
workpackage_id=workpackage_id,
|
||||
person_id=self.person_id,
|
||||
role=role,
|
||||
frontend_post_id=frontend_post_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
)
|
||||
self.db.add(post)
|
||||
self.db.commit()
|
||||
self.db.refresh(post)
|
||||
return post
|
||||
|
||||
def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
raise ValueError("User required")
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
post_id=post_id,
|
||||
person_id=self.person_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
@@ -1,6 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
|
||||
CREATE DATABASE ocgpt_backend;
|
||||
EOSQL
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
export DATABASE_URI=postgresql://postgres:postgres@localhost:5432/postgres
|
||||
|
||||
uvicorn app.main:app --reload
|
||||
@@ -1,13 +0,0 @@
|
||||
install:
|
||||
python -m pip install -U pip
|
||||
python -m pip install -e .
|
||||
|
||||
lint: ## [Local development] Run pylint and black
|
||||
python -m pylint app
|
||||
python -m black --check -l 120 app
|
||||
|
||||
black: ## [Local development] Auto-format python code using black
|
||||
python -m black -l 120 .
|
||||
|
||||
run:
|
||||
python -m bot
|
||||
@@ -1,14 +0,0 @@
|
||||
# open-chat-gpt
|
||||
|
||||
This is the github repo for the open-chat-gpt project.
|
||||
We are currently building a discord bot in order to make everyone contribute with great prompts and answers.
|
||||
Join us!
|
||||
https://discord.gg/ZUfPw6jP
|
||||
|
||||
## Project description
|
||||
|
||||
We are calling the community for help to collect ChatGPT-like Instruction-Fulfillment datasamples via Discord. People can post Instructions they think would make sense for ChatGPT-like systems & also provide a good reference answer for it.
|
||||
|
||||
## Todo
|
||||
|
||||
Figure out ouath flow for the app to work inside the open-chat-gpt testing channel here. https://discord.gg/JJSKtRhv
|
||||
-207
@@ -1,207 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import discord
|
||||
import requests
|
||||
from discord import app_commands
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
bot_url = "https://discord.com/api/oauth2/authorize?client_id=1051614245940375683&permissions=8&scope=bot"
|
||||
|
||||
# Load up all the important environment variables.
|
||||
load_dotenv()
|
||||
|
||||
# For authentication.
|
||||
TOKEN = os.getenv("DISCORD_TOKEN")
|
||||
|
||||
# For Backends.
|
||||
API_SERVER_URL = os.getenv("API_SERVER_URL")
|
||||
API_SERVER_KEY = os.getenv("API_SERVER_KEY")
|
||||
|
||||
labelers_url = f"{API_SERVER_URL}/api/v1/labelers/"
|
||||
prompts_url = f"{API_SERVER_URL}/api/v1/prompts/"
|
||||
headers = {"X-API-Key": API_SERVER_KEY}
|
||||
|
||||
# For testing only.
|
||||
TEST_GUILD = os.getenv("TEST_GUILD")
|
||||
|
||||
|
||||
# Initiate the client and command tree to create slash commands.
|
||||
class OpenChatGPTClient(discord.Client):
|
||||
def __init__(self, *, intents: discord.Intents):
|
||||
super().__init__(intents=intents)
|
||||
self.tree = app_commands.CommandTree(self)
|
||||
|
||||
async def setup_hook(self):
|
||||
if TEST_GUILD:
|
||||
# When testing the bot it's handy to run in a single server (called a
|
||||
# Guide in the API). This is relatively fast.
|
||||
guild = discord.Object(id=TEST_GUILD)
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
else:
|
||||
# This can take up to an hour for the commands to be registered.
|
||||
await self.tree.sync()
|
||||
logger.debug("Ready!")
|
||||
|
||||
|
||||
# List the set of intents needed for commands to operate properly.
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
client = OpenChatGPTClient(intents=intents)
|
||||
|
||||
|
||||
class LikeButton(discord.ui.Button):
|
||||
def __init__(self, label, channel, username, prompt):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👍")
|
||||
self.channel = channel
|
||||
self.username = username
|
||||
self.prompt = prompt
|
||||
|
||||
async def callback(self, interaction):
|
||||
# interaction holds the interaction object
|
||||
# await interaction.response.defer()
|
||||
await interaction.response.send_message("Thanks for your feedback. You liked this 👍 ")
|
||||
|
||||
|
||||
class NeutralButton(discord.ui.Button):
|
||||
def __init__(self, label, channel, username, prompt):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="😐")
|
||||
self.channel = channel
|
||||
self.username = username
|
||||
self.prompt = prompt
|
||||
|
||||
async def callback(self, interaction):
|
||||
# interaction holds the interaction object
|
||||
# await interaction.response.defer()
|
||||
await interaction.response.send_message("Thanks for your feedback. You thought this was neutral 😐 ")
|
||||
|
||||
|
||||
class DislikeButton(discord.ui.Button):
|
||||
def __init__(self, label, channel, username, prompt):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👎")
|
||||
self.channel = channel
|
||||
self.username = username
|
||||
self.prompt = prompt
|
||||
|
||||
async def callback(self, interaction):
|
||||
# interaction holds the interaction object
|
||||
# await interaction.response.defer()
|
||||
# send the feedback to the backend #
|
||||
await interaction.response.send_message("Thanks for your feedback. You disliked this 👎 ")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def register(interaction: discord.Interaction):
|
||||
"""Registers the user for submissions."""
|
||||
labeler = {
|
||||
"discord_username": f"{interaction.user.id}",
|
||||
"display_name": interaction.user.name,
|
||||
"is_enabled": True,
|
||||
}
|
||||
response = requests.post(labelers_url, headers=headers, json=labeler)
|
||||
if response.status_code == 200:
|
||||
await interaction.response.send_message(f"Added you {interaction.user.name}")
|
||||
else:
|
||||
logger.debug(response)
|
||||
await interaction.response.send_message("Failed to add you")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def list_participants(interaction: discord.Interaction):
|
||||
"""Reports the set of registered participants."""
|
||||
response = requests.get(labelers_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
names = ",".join([labeler["display_name"] for labeler in response.json()])
|
||||
await interaction.response.send_message(f"Found these users: {names}")
|
||||
else:
|
||||
await interaction.response.send_message("Failed to fetch participants")
|
||||
|
||||
|
||||
async def send_prompt_with_response_and_button(channel, username, prompt, response):
|
||||
await channel.send(f"What do you think about the following interaction: \nprompt: {prompt} \nresponse: {response}")
|
||||
# await channel.send(f'Please click on the button that best describes your reaction to the response:')
|
||||
|
||||
# add buttons
|
||||
view = discord.ui.View()
|
||||
like = LikeButton(label="Like", channel=channel, username=username, prompt=prompt)
|
||||
neutral = NeutralButton(label="Neutral", channel=channel, username=username, prompt=prompt)
|
||||
dislike = DislikeButton(label="Dislike", channel=channel, username=username, prompt=prompt)
|
||||
|
||||
view.add_item(item=like)
|
||||
view.add_item(item=neutral)
|
||||
view.add_item(item=dislike)
|
||||
await channel.send(view=view)
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def review_prompts(interaction: discord.Interaction, number_of_prompts: int):
|
||||
# get the prompt from the db
|
||||
url = f"{prompts_url}?begin_id=0&limit={number_of_prompts}"
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
prompts = response.json()
|
||||
logger.debug("the responses are:", prompts)
|
||||
for prompt in prompts:
|
||||
await send_prompt_with_response_and_button(
|
||||
interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
|
||||
)
|
||||
else:
|
||||
await interaction.response.send_message("Failed to get prompts for review")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def add_prompt(interaction: discord.Interaction, prompt: str, response: str, language: str = "en"):
|
||||
"""Uploads a single prompt to the server."""
|
||||
prompt = {
|
||||
"discord_username": f"{interaction.user.id}",
|
||||
"labeler_id": 5,
|
||||
"prompt": prompt,
|
||||
"response": response,
|
||||
"lang": language,
|
||||
}
|
||||
response = requests.post(prompts_url, headers=headers, json=prompt)
|
||||
if response.status_code == 200:
|
||||
await send_prompt_with_response_and_button(
|
||||
interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
|
||||
)
|
||||
# send the prompt back with buttons for the user to click on
|
||||
# await interaction.response.send_message("Added your prompt")
|
||||
else:
|
||||
await interaction.response.send_message("Failed to add the prompt")
|
||||
|
||||
|
||||
@client.tree.command()
|
||||
async def add_prompts_set(interaction: discord.Interaction, prompts: discord.Attachment):
|
||||
"""Uploads a batch of prompts to the server."""
|
||||
# Loading a bunch of prompts from a file can take a while. So first defer
|
||||
# the response to ensure we're able to later tell the user what happened.
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Read the prompts and load them one by one.
|
||||
# TODO: Upload a batch when the API supports it.
|
||||
# TODO: Handle incorrect file types and parsing errors.
|
||||
prompts_raw = await prompts.read()
|
||||
prompts_loaded = json.loads(prompts_raw)
|
||||
count = 0
|
||||
for entry in prompts_loaded:
|
||||
for response in entry["responses"]:
|
||||
prompt = {
|
||||
"discord_username": f"{interaction.user.id}",
|
||||
"labeler_id": 5,
|
||||
"prompt": entry["prompt"],
|
||||
"response": response,
|
||||
"lang": "en",
|
||||
}
|
||||
response = requests.post(prompts_url, headers=headers, json=prompt)
|
||||
if response.status_code != 200:
|
||||
await interaction.followup.send("Failed to upload")
|
||||
return
|
||||
count += 1
|
||||
await interaction.followup.send(f"Loaded up {count} prompts")
|
||||
|
||||
|
||||
client.run(TOKEN)
|
||||
@@ -1,2 +0,0 @@
|
||||
discord.py==2.1.0
|
||||
python-dotenv==0.21.0
|
||||
@@ -1,29 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
def _read_reqs(relpath):
|
||||
fullpath = os.path.join(os.path.dirname(__file__), relpath)
|
||||
with open(fullpath) as f:
|
||||
return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]
|
||||
|
||||
REQUIREMENTS = _read_reqs("requirements.txt")
|
||||
|
||||
setup(
|
||||
name="open-chat-gpt",
|
||||
packages=find_packages(),
|
||||
version="0.0.1",
|
||||
license="Apache 2.0",
|
||||
description="A Discord Bot for collecting and ranking prompts to train an Open ChatGPT",
|
||||
keywords=["machine learning", "natural language processing", "discord"],
|
||||
install_requires=REQUIREMENTS,
|
||||
classifiers=[
|
||||
"Development Status :: Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"License :: OSI Approved :: Apache License",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
],
|
||||
)
|
||||
@@ -1,10 +0,0 @@
|
||||
[
|
||||
{
|
||||
"prompt": "tell me the name of two dogs",
|
||||
"responses": ["Charles", "bobby"]
|
||||
},
|
||||
{
|
||||
"prompt": "Name one type of cheese made in france",
|
||||
"responses": ["Munster", "Gouda"]
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,20 @@
|
||||
# Open-Assistant Data Collection Discord Bot
|
||||
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large langugae model. You and other people can teach the bot how to respond to user requests by demonstration and by garding and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
|
||||
## Invite official bot
|
||||
|
||||
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
|
||||
|
||||
## Bot token for development
|
||||
|
||||
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
|
||||
|
||||
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
|
||||
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
|
||||
|
||||
The simplest way to configure the token is via an `.env` file:
|
||||
|
||||
```
|
||||
BOT_TOKEN=XYZABC123...
|
||||
```
|
||||
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from bot import OpenAssistantBot
|
||||
from bot_settings import settings
|
||||
|
||||
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = OpenAssistantBot(
|
||||
settings.BOT_TOKEN,
|
||||
bot_channel_name=settings.BOT_CHANNEL_NAME,
|
||||
backend_url=settings.BACKEND_URL,
|
||||
api_key=settings.API_KEY,
|
||||
owner_id=settings.OWNER_ID,
|
||||
template_dir=settings.TEMPLATE_DIR,
|
||||
debug=settings.DEBUG,
|
||||
)
|
||||
bot.run()
|
||||
@@ -0,0 +1,74 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import requests
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
task_models_map: dict[str, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
self.task_models_map = task_models_map
|
||||
|
||||
def post(self, path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _parse_task(self, data: dict) -> protocol_schema.Task:
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("dict expected")
|
||||
|
||||
task_type = data.get("type")
|
||||
if task_type not in self.task_models_map:
|
||||
raise RuntimeError(f"Unsupported task type: {task_type}")
|
||||
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
def fetch_task(
|
||||
self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
|
||||
) -> protocol_schema.Task:
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user)
|
||||
data = self.post("/api/v1/tasks/", req.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user)
|
||||
|
||||
def ack_task(self, task_id: str, post_id: str) -> None:
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
|
||||
|
||||
def nack_task(self, task_id: str, reason: str) -> None:
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
|
||||
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
data = self.post("/api/v1/tasks/interaction", interaction.dict())
|
||||
return self._parse_task(data)
|
||||
@@ -0,0 +1,283 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import discord
|
||||
import task_handlers
|
||||
from api_client import ApiClient, TaskType
|
||||
from bot_base import BotBase
|
||||
from discord import app_commands
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import get_git_head_hash, utcnow
|
||||
|
||||
__version__ = "0.0.3"
|
||||
BOT_NAME = "Open-Assistant Junior"
|
||||
|
||||
|
||||
class OpenAssistantBot(BotBase):
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: str,
|
||||
bot_channel_name: str,
|
||||
backend_url: str,
|
||||
api_key: str,
|
||||
owner_id: Optional[Union[int, str]] = None,
|
||||
template_dir: str = "./templates",
|
||||
debug: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.template_dir = Path(template_dir)
|
||||
self.bot_channel_name = bot_channel_name
|
||||
self.templates = MessageTemplates(template_dir)
|
||||
self.debug = debug
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
if isinstance(owner_id, str):
|
||||
owner_id = int(owner_id)
|
||||
self.owner_id = owner_id
|
||||
|
||||
self.bot_token = bot_token
|
||||
client = discord.Client(intents=intents)
|
||||
self.client = client
|
||||
self.loop = client.loop
|
||||
|
||||
self.bot_channel: discord.TextChannel = None
|
||||
self.backend = ApiClient(backend_url, api_key)
|
||||
|
||||
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
|
||||
logger.info(f"{client.user} is now running!")
|
||||
|
||||
await self.delete_all_old_bot_messages()
|
||||
# if self.debug:
|
||||
# await self.post_boot_message()
|
||||
await self.post_welcome_message()
|
||||
|
||||
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message):
|
||||
# ignore own messages
|
||||
if message.author != client.user:
|
||||
await self.handle_message(message)
|
||||
|
||||
@self.tree.command()
|
||||
async def tutorial(interaction: discord.Interaction):
|
||||
"""Start the Open-Assistant tutorial via DMs."""
|
||||
|
||||
dm = await self.client.create_dm(discord.Object(interaction.user.id))
|
||||
await dm.send("Tutorial coming soon... :-)")
|
||||
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
|
||||
|
||||
@self.tree.command()
|
||||
async def help(interaction: discord.Interaction):
|
||||
"""Sends the user a list of all available commands"""
|
||||
await self.post_help(interaction.user)
|
||||
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
|
||||
|
||||
@self.tree.command()
|
||||
async def work(interaction: discord.Interaction):
|
||||
"""Request a new personalized task"""
|
||||
|
||||
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
|
||||
# task = self.backend.fetch_random_task(user=None)
|
||||
q = task_handlers.Questionnaire()
|
||||
await interaction.response.send_modal(q)
|
||||
|
||||
async def post_help(self, user: discord.abc.User) -> discord.Message:
|
||||
is_bot_owner = user.id == self.owner_id
|
||||
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
|
||||
|
||||
async def post_boot_message(self) -> discord.Message:
|
||||
return await self.post_template(
|
||||
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
|
||||
)
|
||||
|
||||
async def post_welcome_message(self) -> discord.Message:
|
||||
return await self.post_template("welcome.msg")
|
||||
|
||||
async def delete_all_old_bot_messages(self) -> None:
|
||||
logger.info("Deleting old threads...")
|
||||
for thread in self.bot_channel.threads:
|
||||
if thread.owner_id == self.client.user.id:
|
||||
await thread.delete()
|
||||
logger.info("Completed deleting old theards.")
|
||||
|
||||
logger.info("Deleting old messages...")
|
||||
look_until = utcnow() - timedelta(days=365)
|
||||
async for msg in self.bot_channel.history(limit=None):
|
||||
msg: discord.Message
|
||||
if msg.created_at < look_until:
|
||||
break
|
||||
if msg.author.id == self.client.user.id:
|
||||
await msg.delete()
|
||||
logger.info("Completed deleting old messages.")
|
||||
|
||||
async def next_task(self):
|
||||
task_type = protocol_schema.TaskRequestType.random
|
||||
task = self.backend.fetch_task(task_type, user=None)
|
||||
|
||||
handler: task_handlers.ChannelTaskBase = None
|
||||
match task.type:
|
||||
case TaskType.summarize_story:
|
||||
handler = task_handlers.SummarizeStoryHandler()
|
||||
case TaskType.rate_summary:
|
||||
handler = task_handlers.RateSummaryHandler()
|
||||
case TaskType.initial_prompt:
|
||||
handler = task_handlers.InitialPromptHandler()
|
||||
case TaskType.user_reply:
|
||||
handler = task_handlers.UserReplyHandler()
|
||||
case TaskType.assistant_reply:
|
||||
handler = task_handlers.AssistantReplyHandler()
|
||||
case TaskType.rank_initial_prompts:
|
||||
handler = task_handlers.RankInitialPromptsHandler()
|
||||
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
|
||||
handler = task_handlers.RankConversationsHandler()
|
||||
case _:
|
||||
logger.warning(f"Unsupported task type received: {task.type}")
|
||||
self.backend.nack_task(task.id, "not supported")
|
||||
|
||||
if handler:
|
||||
try:
|
||||
logger.info(f"strarting task {task.id}")
|
||||
msg = await handler.start(self, task)
|
||||
self.backend.ack_task(task.id, msg.id)
|
||||
except Exception:
|
||||
logger.exception("Starting task failed.")
|
||||
self.backend.nack_task(task.id, "faled")
|
||||
|
||||
async def background_timer(self):
|
||||
next_remove_completed = utcnow() + timedelta(seconds=10)
|
||||
next_fetch_task = utcnow() + timedelta(seconds=1)
|
||||
while True:
|
||||
now = utcnow()
|
||||
|
||||
if self.bot_channel:
|
||||
if now > next_fetch_task:
|
||||
next_fetch_task = utcnow() + timedelta(seconds=60)
|
||||
|
||||
try:
|
||||
await self.next_task()
|
||||
except Exception:
|
||||
logger.exception("fetching next task failed")
|
||||
|
||||
for x in self.reply_handlers.values():
|
||||
x.handler.tick(now)
|
||||
|
||||
if now > next_remove_completed:
|
||||
next_remove_completed = utcnow() + timedelta(seconds=10)
|
||||
await self.remove_completed_handlers()
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _sync(self, command: str, message: discord.Message):
|
||||
|
||||
logger.info(f"sync tree command received: {command}")
|
||||
|
||||
if command == "sync.copy_global":
|
||||
await self.tree.copy_global_to(guild=message.guild)
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
elif command == "sync.clear_guild":
|
||||
self.tree.clear_commands(guild=message.guild)
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
elif command == "sync.guild":
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
else:
|
||||
synced = await self.tree.sync()
|
||||
|
||||
logger.info(f"Synced {len(synced)} commands")
|
||||
await message.reply(f"Synced {len(synced)} commands")
|
||||
|
||||
async def handle_command(self, message: discord.Message, is_owner: bool):
|
||||
command_text: str = message.content
|
||||
command_text = command_text[1:]
|
||||
match command_text:
|
||||
case "help" | "?":
|
||||
await self.post_help(user=message.author)
|
||||
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
|
||||
if is_owner:
|
||||
await self._sync(command_text, message)
|
||||
case _:
|
||||
await message.reply(f"unknown command: {command_text}")
|
||||
|
||||
def recipient_filter(self, message: discord.Message) -> bool:
|
||||
channel = message.channel
|
||||
|
||||
if (
|
||||
message.channel.type == discord.ChannelType.private
|
||||
or message.channel.type == discord.ChannelType.private_thread
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
message.channel.type == discord.ChannelType.text
|
||||
or message.channel.type == discord.ChannelType.public_thread
|
||||
):
|
||||
while channel:
|
||||
if self.bot_channel and channel.id == self.bot_channel.id:
|
||||
return True
|
||||
channel = channel.parent
|
||||
|
||||
return False
|
||||
|
||||
async def handle_message(self, message: discord.Message):
|
||||
if not self.recipient_filter(message):
|
||||
return
|
||||
|
||||
user_id = message.author.id
|
||||
user_display_name = message.author.name
|
||||
|
||||
logger.debug(
|
||||
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
|
||||
)
|
||||
|
||||
command_prefix = "!"
|
||||
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
|
||||
is_owner = self.owner_id and user_id == self.owner_id
|
||||
await self.handle_command(message, is_owner)
|
||||
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
handler = self.reply_handlers.get(message.channel.id)
|
||||
if handler and not handler.handler.completed:
|
||||
handler.handler.on_reply(message)
|
||||
|
||||
if message.reference:
|
||||
handler = self.reply_handlers.get(message.reference.message_id)
|
||||
if handler and not handler.handler.completed:
|
||||
handler.handler.on_reply(message)
|
||||
|
||||
async def remove_completed_handlers(self):
|
||||
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
|
||||
if len(completed) == 0:
|
||||
return
|
||||
|
||||
for c in completed:
|
||||
handler = self.reply_handlers[c]
|
||||
del self.reply_handlers[c]
|
||||
try:
|
||||
await handler.handler.finalize()
|
||||
except Exception:
|
||||
logger.exception("handler finalize failed")
|
||||
|
||||
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
|
||||
|
||||
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
|
||||
for channel in self.client.get_all_channels():
|
||||
if channel.type == discord.ChannelType.text and channel.name == channel_name:
|
||||
return channel
|
||||
|
||||
def run(self):
|
||||
"""Run bot loop blocking."""
|
||||
self.client.run(self.bot_token)
|
||||
@@ -0,0 +1,61 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from channel_handlers import ChannelHandlerBase
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplyHandlerInfo:
|
||||
msg_id: int
|
||||
handler_task: asyncio.Task
|
||||
handler: ChannelHandlerBase
|
||||
|
||||
|
||||
class BotBase(ABC):
|
||||
bot_channel_name: str
|
||||
debug: bool
|
||||
backend: ApiClient
|
||||
client: discord.Client
|
||||
loop: asyncio.BaseEventLoop
|
||||
owner_id: int
|
||||
bot_channel: discord.TextChannel
|
||||
templates: MessageTemplates
|
||||
reply_handlers: dict[int, ReplyHandlerInfo]
|
||||
|
||||
def __init__(self):
|
||||
self.reply_handlers = {} # handlers by msg_id
|
||||
|
||||
def ensure_bot_channel(self) -> None:
|
||||
if self.bot_channel is None:
|
||||
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
|
||||
|
||||
async def post(
|
||||
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
|
||||
) -> discord.Message:
|
||||
if channel is None:
|
||||
self.ensure_bot_channel()
|
||||
channel = self.bot_channel
|
||||
return await channel.send(content=content, view=view)
|
||||
|
||||
async def post_template(
|
||||
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
|
||||
) -> discord.Message:
|
||||
logger.debug(f"rendering {name}")
|
||||
text = self.templates.render(name, **kwargs)
|
||||
return await self.post(text, view=view, channel=channel)
|
||||
|
||||
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
|
||||
if msg_id in self.reply_handlers:
|
||||
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
|
||||
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
|
||||
task.add_done_callback(lambda t: handler.on_completed())
|
||||
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
|
||||
@@ -0,0 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import AnyHttpUrl, BaseSettings
|
||||
|
||||
|
||||
class BotSettings(BaseSettings):
|
||||
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
|
||||
API_KEY: str = "any_key"
|
||||
BOT_TOKEN: str
|
||||
BOT_CHANNEL_NAME: str = "bot"
|
||||
OWNER_ID: int = None
|
||||
TEMPLATE_DIR: str = "./templates"
|
||||
DEBUG: bool = True
|
||||
|
||||
|
||||
settings = BotSettings(_env_file=".env")
|
||||
@@ -0,0 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
import discord
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ChannelExpiredException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ChannelHandlerBase(ABC):
|
||||
queue: asyncio.Queue
|
||||
completed: bool = False
|
||||
expiry_date: datetime
|
||||
expired: bool = False
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
self.expiry_date = expiry_date
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
"""Call this method to read the next message from the user in the handler method."""
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
|
||||
msg = await self.queue.get()
|
||||
if msg is None:
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
else:
|
||||
raise RuntimeError("Unexpected None message read")
|
||||
return msg
|
||||
|
||||
def on_reply(self, message: discord.Message) -> None:
|
||||
self.queue.put_nowait(message)
|
||||
|
||||
def on_expire(self) -> None:
|
||||
logger.info("ChannelHandler: on_expire")
|
||||
self.expired = True
|
||||
self.queue.put_nowait(None)
|
||||
|
||||
def on_completed(self) -> None:
|
||||
logger.info("ChannelHandler: on_completed")
|
||||
self.completed = True
|
||||
|
||||
def tick(self, now: datetime):
|
||||
if now > self.expiry_date and not self.expired:
|
||||
self.on_expire()
|
||||
|
||||
@abstractmethod
|
||||
async def handler_loop(self):
|
||||
...
|
||||
|
||||
async def finalize(self):
|
||||
pass
|
||||
|
||||
|
||||
class AutoDestructThreadHandler(ChannelHandlerBase):
|
||||
first_message: discord.Message = None
|
||||
thread: discord.Thread = None
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
super().__init__(expiry_date=expiry_date)
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
try:
|
||||
return await super().read()
|
||||
except ChannelExpiredException:
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def cleanup(self):
|
||||
logger.debug("AutoDestructThreadHandler.cleanup")
|
||||
if self.thread:
|
||||
logger.debug(f"deleting thread: {self.thread.name}")
|
||||
await self.thread.delete()
|
||||
self.thread = None
|
||||
if self.first_message:
|
||||
logger.debug(f"deleting first_message: {self.first_message.content}")
|
||||
await self.first_message.delete()
|
||||
self.first_message = None
|
||||
|
||||
async def finalize(self):
|
||||
await self.cleanup()
|
||||
return await super().finalize()
|
||||
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import jinja2
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageTemplates:
|
||||
def __init__(self, template_dir="./templates"):
|
||||
self.env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(template_dir),
|
||||
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
|
||||
)
|
||||
|
||||
def render(self, template_name, **kwargs):
|
||||
template = self.env.get_template(template_name)
|
||||
txt = template.render(kwargs)
|
||||
logger.debug(txt)
|
||||
|
||||
return txt
|
||||
@@ -0,0 +1,7 @@
|
||||
discord.py==2.1.0
|
||||
Jinja2==3.1.2
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
pytz==2022.7
|
||||
requests==2.28.1
|
||||
schedule==1.1.0
|
||||
@@ -0,0 +1,267 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from bot_base import BotBase
|
||||
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
|
||||
|
||||
|
||||
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
|
||||
name = discord.ui.TextInput(label="Name")
|
||||
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
|
||||
|
||||
|
||||
class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
thread_name: str = "Replies"
|
||||
expires_after: timedelta = timedelta(minutes=5)
|
||||
backend: ApiClient
|
||||
|
||||
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
|
||||
try:
|
||||
self.bot = bot
|
||||
self.task = task
|
||||
self.backend = bot.backend
|
||||
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
|
||||
msg = await self.send_first_message()
|
||||
self.first_message = msg
|
||||
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
|
||||
await self.on_thread_created(self.thread)
|
||||
except Exception:
|
||||
logger.exception("start task failed")
|
||||
await self.cleanup() # try to cleanup messag or thread
|
||||
raise
|
||||
|
||||
bot.register_reply_handler(msg_id=msg.id, handler=self)
|
||||
return msg
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_first_message(self) -> discord.message:
|
||||
...
|
||||
|
||||
def to_api_user(self, user: discord.User) -> protocol_schema.User:
|
||||
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
|
||||
|
||||
async def post_teaser_msg(self, template_name: str):
|
||||
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
|
||||
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
|
||||
return await self.bot.post_template(
|
||||
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
|
||||
)
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
api_response = await self.backend.post_interaction(interaction)
|
||||
if api_response.type != "task_done":
|
||||
# multi-step tasks are not supported yet
|
||||
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
|
||||
raise RuntimeError("Unexpected response from backend received")
|
||||
return api_response
|
||||
|
||||
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.TextReplyToPost(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
text=user_msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
self.post_text_reply_to_post(user_msg)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_text_reply_to_post()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.PostRanking(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
ranking=ranking,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
ranking_str = user_msg.content
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
self.post_ranking(user_msg, ranking=ranking)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_ranking()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
|
||||
class SummarizeStoryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.SummarizeStoryTask
|
||||
thread_name: str = "Summaries"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_summarize_story.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class InitialPromptHandler(ChannelTaskBase):
|
||||
task: protocol_schema.InitialPromptTask
|
||||
thread_name: str = "Prompts"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_initial_prompt.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class UserReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.UserReplyTask
|
||||
thread_name: str = "User replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_user_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class AssistantReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.AssistantReplyTask
|
||||
thread_name: str = "Assistant replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_assistant_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RankInitialPromptsTask
|
||||
thread_name: str = "User Responses"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RankConversationsHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RankConversationRepliesTask
|
||||
thread_name: str = "Rankings"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RatingButton(discord.ui.Button):
|
||||
def __init__(self, label, value, response_handler):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green)
|
||||
self.value = value
|
||||
self.response_handler = response_handler
|
||||
|
||||
async def callback(self, interaction):
|
||||
await self.response_handler(self.value, interaction)
|
||||
|
||||
|
||||
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
|
||||
view = discord.ui.View()
|
||||
for i in range(lo, hi + 1):
|
||||
view.add_item(RatingButton(str(i), i, response_handler))
|
||||
return view
|
||||
|
||||
|
||||
class RateSummaryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RateSummaryTask
|
||||
thread_name: str = "Ratings"
|
||||
|
||||
async def _rating_response_handler(self, score, interaction: discord.Interaction):
|
||||
logger.info("rating_response_handler", score)
|
||||
if self.thread:
|
||||
try:
|
||||
self.backend.post_interaction(
|
||||
protocol_schema.PostRating(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(interaction.id),
|
||||
user=self.to_api_user(interaction.user),
|
||||
rating=score,
|
||||
)
|
||||
)
|
||||
await interaction.response.send_message(
|
||||
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
|
||||
)
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in _rating_response_handler()")
|
||||
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rate_summary.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
|
||||
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info(f"on_rate_summary_reply: {msg.content}")
|
||||
await msg.add_reaction("❌")
|
||||
await msg.reply("❌ Text intput not supported.")
|
||||
@@ -0,0 +1,13 @@
|
||||
```
|
||||
________ __
|
||||
\_____ \ _____ ______ _______/ |_
|
||||
/ | \\__ \ / ___// ___/\ __\
|
||||
/ | \/ __ \_\___ \ \___ \ | |
|
||||
\_______ (____ /____ >____ > |__|
|
||||
\/ \/ \/ \/
|
||||
|
||||
{{bot_name}} {{version}}
|
||||
git hash: {{git_hash}}
|
||||
debug_mode: {{debug}}
|
||||
```
|
||||
https://github.com/LAION-AI/Open-Assistant
|
||||
@@ -0,0 +1,15 @@
|
||||
**Open-Assistant Bot Help**
|
||||
|
||||
Available slash-commands:
|
||||
|
||||
`/work` Requests a new personalized human feedback task
|
||||
`/help` Show this message
|
||||
|
||||
{% if is_bot_owner %}
|
||||
Commands for bot owners:
|
||||
|
||||
`!sync`
|
||||
`!sync.guild`
|
||||
`!sync.copy_global`
|
||||
`!sync.clear_guild`
|
||||
{% endif %}
|
||||
@@ -0,0 +1,12 @@
|
||||
Act as the assistant and reply to the user.
|
||||
Here is the conversation so far:
|
||||
{% for message in task.conversation.messages %}
|
||||
{% if message.is_assistant %}
|
||||
:robot: Assistant:
|
||||
{{ message.text }}
|
||||
{% else %}
|
||||
:person_red_hair: User:
|
||||
**{{ message.text }}**"
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
:robot: Assistant: { human, pls help me! ... }
|
||||
@@ -0,0 +1,4 @@
|
||||
Please provide an initial prompt to the assistant.
|
||||
{% if task.hint is not none %}
|
||||
Hint: {{task.hint}}
|
||||
{% endif %}
|
||||
@@ -0,0 +1,13 @@
|
||||
Here is the conversation so far:
|
||||
{% for message in task.conversation.messages %}{% if message.is_assistant %}
|
||||
:robot: Assistant:
|
||||
{{ message.text }}
|
||||
{% else %}
|
||||
:person_red_hair: User:
|
||||
**{{ message.text }}**"
|
||||
{% endif %}{% endfor %}
|
||||
Rank the following replies:
|
||||
{% for reply in task.replies %}
|
||||
{{loop.index}}: {{reply}}{% endfor %}
|
||||
|
||||
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
|
||||
@@ -0,0 +1,5 @@
|
||||
Rank the following prompts:
|
||||
{% for prompt in task.prompts %}
|
||||
{{loop.index}}: {{prompt}}{% endfor %}
|
||||
|
||||
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
|
||||
@@ -0,0 +1,7 @@
|
||||
Rate the following summary:
|
||||
{{task.summary}}
|
||||
|
||||
Full text:
|
||||
{{task.full_text}}
|
||||
|
||||
Rating scale: {{task.scale.min}} - {{task.scale.max}}
|
||||
@@ -0,0 +1,2 @@
|
||||
Summarize to the following story:
|
||||
{{task.story}}
|
||||
@@ -0,0 +1,12 @@
|
||||
Please provide a reply to the assistant.
|
||||
Here is the conversation so far:
|
||||
{% for message in task.conversation.messages %}{% if message.is_assistant %}
|
||||
:robot: Assistant:
|
||||
{{ message.text }}
|
||||
{% else %}
|
||||
:person_red_hair: User:
|
||||
**{{ message.text }}**"
|
||||
{% endif %}{% endfor %}
|
||||
{% if task.hint %}
|
||||
Hint: {{ task.hint }}
|
||||
{% endif %}
|
||||
@@ -0,0 +1,3 @@
|
||||
:robot: **Challenge: Assistant Reply**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:microphone2: **Challenge: Initial Prompt**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:bar_chart: **Challenge: Rank Replies**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:bar_chart: **Challenge: Rank Initial Prompts**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:ballot_box: **Challenge: Rate Summary**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:books: **Challenge: Summarize Story**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:person_red_hair: **Challenge: User Reply**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
|
||||
@@ -0,0 +1,6 @@
|
||||
Hi there,
|
||||
|
||||
I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗!
|
||||
Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world.
|
||||
|
||||
Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands.
|
||||
@@ -0,0 +1,52 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
def get_git_head_hash():
|
||||
# get current git hash
|
||||
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
if x.returncode == 0:
|
||||
return x.stdout.replace("\n", "")
|
||||
return None
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(pytz.UTC)
|
||||
|
||||
|
||||
class DiscordTimestampStyle(str, enum.Enum):
|
||||
"""
|
||||
Timestamp Styles
|
||||
|
||||
t 16:20 Short Time
|
||||
T 16:20:30 Long Time
|
||||
d 20/04/2021 Short Date
|
||||
D 20 April 2021 Long Date
|
||||
f * 20 April 2021 16:20 Short Date/Time
|
||||
F Tuesday, 20 April 2021 16:20 Long Date/Time
|
||||
R 2 months ago Relative Time
|
||||
|
||||
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
|
||||
"""
|
||||
|
||||
default = ""
|
||||
short_time = "t"
|
||||
long_time = "T"
|
||||
short_date = "d"
|
||||
long_date = "D"
|
||||
short_date_time = "f"
|
||||
long_date_time = "F"
|
||||
relative_time = "R"
|
||||
|
||||
|
||||
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
|
||||
parts = ["<t:", str(int(d.timestamp()))]
|
||||
if style:
|
||||
parts.append(":")
|
||||
parts.append(style)
|
||||
parts.append(">")
|
||||
return "".join(parts)
|
||||
@@ -0,0 +1,15 @@
|
||||
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10
|
||||
|
||||
COPY ./backend/requirements.txt /app/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
|
||||
|
||||
ENV PORT 8080
|
||||
|
||||
COPY ./oasst-shared /oasst-shared
|
||||
RUN pip install -e /oasst-shared
|
||||
|
||||
COPY ./backend/alembic /app/alembic
|
||||
COPY ./backend/alembic.ini /app/alembic.ini
|
||||
COPY ./backend/main.py /app/main.py
|
||||
COPY ./backend/oasst_backend /app/oasst_backend
|
||||
@@ -0,0 +1,7 @@
|
||||
FROM python:3.10-slim-bullseye
|
||||
RUN mkdir /app
|
||||
COPY ./discord-bot/requirements.txt /requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
WORKDIR /app
|
||||
COPY ./discord-bot /app
|
||||
CMD ["python", "bot.py"]
|
||||
@@ -0,0 +1,62 @@
|
||||
# Install dependencies only when needed
|
||||
FROM node:16.19 AS deps
|
||||
# Check https://github.com/nodejs/docker-node/tree/b4117f9333da4138b03a546ec926ef50a31506c3#nodealpine to understand why libc6-compat might be needed.
|
||||
# RUN apk add --no-cache libc6-compat
|
||||
WORKDIR /app
|
||||
|
||||
# Install dependencies based on the preferred package manager
|
||||
COPY ./website/package.json ./website/package-lock.json ./
|
||||
RUN \
|
||||
if [ -f package-lock.json ]; then npm ci; \
|
||||
else echo "Lockfile not found." && exit 1; \
|
||||
fi
|
||||
|
||||
# Rebuild the source code only when needed
|
||||
FROM node:16.19 AS builder
|
||||
WORKDIR /app
|
||||
COPY --from=deps /app/node_modules ./node_modules
|
||||
COPY ./website/ .
|
||||
|
||||
# Next.js collects completely anonymous telemetry data about general usage.
|
||||
# Learn more here: https://nextjs.org/telemetry
|
||||
# Uncomment the following line in case you want to disable telemetry during the build.
|
||||
# ENV NEXT_TELEMETRY_DISABLED 1
|
||||
|
||||
# RUN yarn build
|
||||
RUN npx prisma generate
|
||||
RUN npm run build
|
||||
|
||||
# Production image, copy all the files and run next
|
||||
FROM node:16.19 AS runner
|
||||
WORKDIR /app
|
||||
|
||||
ENV NODE_ENV production
|
||||
# Uncomment the following line in case you want to disable telemetry during runtime.
|
||||
# ENV NEXT_TELEMETRY_DISABLED 1
|
||||
|
||||
RUN addgroup --system --gid 1001 nodejs
|
||||
RUN adduser --system --uid 1001 nextjs
|
||||
|
||||
COPY --from=builder /app/public ./public
|
||||
|
||||
# Copy over the prisma schema so we can to `npx prisma db push` and ensure the
|
||||
# database exists on startup.
|
||||
COPY --chown=nextjs:nodejs ./website/prisma/schema.prisma ./
|
||||
# Copy over a startup script that'll run `npx prisma db push` before starting
|
||||
# the webserver. This ensures the webserver can actually check user accounts.
|
||||
# This is a prisma variant of the postgres solution suggested in
|
||||
# https://docs.docker.com/compose/startup-order/
|
||||
COPY --chown=nextjs:nodejs ./website/wait-for-postgres.sh ./
|
||||
|
||||
# Automatically leverage output traces to reduce image size
|
||||
# https://nextjs.org/docs/advanced-features/output-file-tracing
|
||||
COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone ./
|
||||
COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static
|
||||
|
||||
USER nextjs
|
||||
|
||||
EXPOSE 3000
|
||||
|
||||
ENV PORT 3000
|
||||
|
||||
CMD ["node", "server.js"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Shared Python code for Open Assisstant
|
||||
|
||||
Run `pip install -e .` to install the package in editable mode.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user