mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
Merge branch 'patch-1' of github.com:markcheeky/Open-Assistant into patch-1
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"service": "frontend-dev",
|
||||
"dockerComposeFile": "../docker-compose.yaml",
|
||||
"forwardPorts": [3000],
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": ["GitHub.copilot"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
**/node_modules
|
||||
@@ -12,7 +12,7 @@ on:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
build-frontend:
|
||||
build-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16.x
|
||||
cache: 'npm'
|
||||
cache: "npm"
|
||||
cache-dependency-path: website/package-lock.json
|
||||
- run: npm ci
|
||||
- run: npm run build
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
name: Test API Contract
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
test-contract:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- run: cd oasst-shared && pip install -e .
|
||||
|
||||
- run: cd oasst-shared && pip install -r requirements.dev.txt
|
||||
|
||||
- run: cd backend && pip install -r requirements.txt
|
||||
|
||||
- run: ./scripts/backend-development/start-mock-server.sh
|
||||
|
||||
- name: Run contract tests
|
||||
run: ./scripts/oasst-shared-development/test.sh
|
||||
|
||||
- run: ./scripts/backend-development/stop-mock-server.sh
|
||||
@@ -5,3 +5,6 @@
|
||||
*.egg-info
|
||||
__pycache__
|
||||
.DS_Store
|
||||
|
||||
# Generated files
|
||||
backend/oasst-openapi.json
|
||||
|
||||
+36
-9
@@ -1,7 +1,32 @@
|
||||
exclude: "build|stubs|^bot/templates/|^notebooks/.*\\.ipynb$"
|
||||
# WARNING!
|
||||
#
|
||||
# When making changes to auto-formatters used in pre-commit hooks, you are
|
||||
# likely to cause merge conflicts with main and/or other pull requests.
|
||||
# Fixing them might revert other people's work. Expect pain!
|
||||
# To avoid accidental reversions and keep it easy to review, please make sure
|
||||
# that changes here are in a pull request by themselves, that it consists of
|
||||
# two commits:
|
||||
#
|
||||
# 1. The changes to this file
|
||||
# 2. Changes made by running `python3 -m pre_commit run --all-files`.
|
||||
#
|
||||
# Then each time your pull request is blocked by a merge conflict, do the
|
||||
# following steps:
|
||||
#
|
||||
# git reset HEAD^1 && git checkout -f # discard the change commit
|
||||
# git rebase main # re-apply other people's changes
|
||||
# python3 -m pre_commit run --all-files # re-run the rules
|
||||
# git add . # add the newly changed files
|
||||
# git commit -m 'apply pre-commit' # commit it
|
||||
# git push -f # force push back to your branch
|
||||
#
|
||||
# Keep in mind you may have to do this a few times, as changes here may impact
|
||||
# other pull requests. Try to keep it up-to-date so they can go in when it'll
|
||||
# cause least disruption.
|
||||
#
|
||||
# /WARNING!
|
||||
|
||||
default_language_version:
|
||||
python: python3
|
||||
exclude: build|stubs|^bot/templates/$
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
@@ -14,11 +39,12 @@ repos:
|
||||
# and which break the standard YAML check. The alternative would be to
|
||||
# skip any unsafe errors (and thus break YAML compatibility) or use
|
||||
# some other checker that may not work in general.
|
||||
exclude: "^copilot/web/addons/.*$"
|
||||
exclude: ^copilot/.*/addons/.*$
|
||||
- id: check-json
|
||||
- id: check-case-conflict
|
||||
- id: detect-private-key
|
||||
- id: fix-encoding-pragma
|
||||
args: [--remove]
|
||||
- id: forbid-submodules
|
||||
- id: mixed-line-ending
|
||||
- id: requirements-txt-fixer
|
||||
@@ -28,13 +54,13 @@ repos:
|
||||
- id: check-symlinks
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
args: ["--maxkb=1024"]
|
||||
args: [--maxkb=1024]
|
||||
- id: end-of-file-fixer
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.12.0
|
||||
hooks:
|
||||
- id: black
|
||||
- id: black-jupyter
|
||||
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 6.0.0
|
||||
@@ -50,14 +76,15 @@ repos:
|
||||
rev: v2.7.1
|
||||
hooks:
|
||||
- id: prettier
|
||||
args: ["--write"]
|
||||
args: [--prose-wrap=always, --write]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: next-lint-website
|
||||
name: Lint website
|
||||
files: ^website/
|
||||
exclude: ^website/node_modules/
|
||||
types_or: [javascript, jsx, ts, tsx]
|
||||
language: system
|
||||
language: node
|
||||
pass_filenames: false
|
||||
entry: bash -c "cd website && npm install && npm run lint"
|
||||
entry: website/next-lint.js
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
* @yk @andreaskoepf
|
||||
/website/ @fozziethebeat @k-nearest-neighbor @AbdBarho
|
||||
/model/ @theblackcat102 @sanagno
|
||||
/copilot/ @fozziethebeat @andreaskoepf @yk
|
||||
|
||||
@@ -1,37 +1,63 @@
|
||||
# Open-Assistant
|
||||
|
||||
Open Assistant is a project meant to give everyone access to a great chat based large language model.
|
||||
Open Assistant is a project meant to give everyone access to a great chat based
|
||||
large language model.
|
||||
|
||||
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope Open Assistant can help improve the world by improving language itself.
|
||||
We believe that by doing this we will create a revolution in innovation in
|
||||
language. In the same way that stable-diffusion helped the world make art and
|
||||
images in new ways we hope Open Assistant can help improve the world by
|
||||
improving language itself.
|
||||
|
||||
## Do you want to try it out?
|
||||
|
||||
If you are interested in taking a look at the current state of the project, You can set up an entire stack needed to run **Open-Assistant**, including the
|
||||
If you are interested in taking a look at the current state of the project, you
|
||||
can set up an entire stack needed to run **Open-Assistant**, including the
|
||||
website, backend, and associated dependent services.
|
||||
|
||||
To start the demo, Run this in the root directory of the repository:
|
||||
To start the demo, run this in the root directory of the repository:
|
||||
|
||||
```sh
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and interact with the website.
|
||||
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
|
||||
interact with the website.
|
||||
|
||||
**Note:** When logging in via email, navigate to `http://localhost:1080` to get the magic email login link.
|
||||
**Note:** When logging in via email, navigate to `http://localhost:1080` to get
|
||||
the magic email login link.
|
||||
|
||||
## The Plan
|
||||
|
||||
We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
|
||||
We want to get to an initial MVP as fast as possible, by following the 3-steps
|
||||
outlined in the InstructGPT paper.
|
||||
|
||||
1. Collect high-quality human generated Instruction-Fulfillment samples (prompt + response), goal >50k. We design a crowdsourced process to collect and reviewed prompts. We do not want to train on flooding/toxic/spam/junk/personal information data. We will have a leaderboard to motivate the community that shows progress and the most active users. Swag will be given to the top-contributors.
|
||||
2. For each of the collected prompts we will sample multiple completions. Completions of one prompt will then be shown randomly to users to rank them from best to worst. Again this should happen crowd-sourced, e.g. we need to deal with unreliable potentially malicious users. At least multiple votes by independent users have to be collected to measure the overall agreement. The gathered ranking-data will be used to train a reward model.
|
||||
3. Now follows the RLHF training phase based on the prompts and the reward model.
|
||||
1. Collect high-quality human generated Instruction-Fulfillment samples
|
||||
(prompt + response), goal >50k. We design a crowdsourced process to collect
|
||||
and reviewed prompts. We do not want to train on
|
||||
flooding/toxic/spam/junk/personal information data. We will have a
|
||||
leaderboard to motivate the community that shows progress and the most active
|
||||
users. Swag will be given to the top-contributors.
|
||||
2. For each of the collected prompts we will sample multiple completions.
|
||||
Completions of one prompt will then be shown randomly to users to rank them
|
||||
from best to worst. Again this should happen crowd-sourced, e.g. we need to
|
||||
deal with unreliable potentially malicious users. At least multiple votes by
|
||||
independent users have to be collected to measure the overall agreement. The
|
||||
gathered ranking-data will be used to train a reward model.
|
||||
3. Now follows the RLHF training phase based on the prompts and the reward
|
||||
model.
|
||||
|
||||
We can then take the resulting model and continue with completion sampling step 2 for a next iteration.
|
||||
We can then take the resulting model and continue with completion sampling step
|
||||
2 for a next iteration.
|
||||
|
||||
## The Vision
|
||||
|
||||
We are not going to stop at replicating ChatGPT. We want to build the assistant of the future, able to not only write email and cover letters, but do meaningful work, use APIs, dynamically research information, and much more, with the ability to be personalized and extended by anyone. And we want to do this in a way that is open and accessible, which means we must not only build a great assistant, but also make it small and efficient enough to run on consumer hardware.
|
||||
We are not going to stop at replicating ChatGPT. We want to build the assistant
|
||||
of the future, able to not only write email and cover letters, but do meaningful
|
||||
work, use APIs, dynamically research information, and much more, with the
|
||||
ability to be personalized and extended by anyone. And we want to do this in a
|
||||
way that is open and accessible, which means we must not only build a great
|
||||
assistant, but also make it small and efficient enough to run on consumer
|
||||
hardware.
|
||||
|
||||
### Slide Decks
|
||||
|
||||
@@ -41,15 +67,20 @@ We are not going to stop at replicating ChatGPT. We want to build the assistant
|
||||
|
||||
## How can you help?
|
||||
|
||||
All open source projects begins with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity.
|
||||
All open source projects begin with people like you. Open source is the belief
|
||||
that if we collaborate we can together gift our knowledge and technology to the
|
||||
world for the benefit of humanity.
|
||||
|
||||
## I’m in! Now what?
|
||||
|
||||
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord), this is for work coordination.
|
||||
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord),
|
||||
this is for work coordination.
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has a dedicated channel and is more public.
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has
|
||||
a dedicated channel and is more public.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a dedicated, but not as active, channel.
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a
|
||||
dedicated, but not as active, channel.
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
@@ -57,15 +88,16 @@ All open source projects begins with people like you. Open source is the belief
|
||||
|
||||
We have a growing task list
|
||||
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
|
||||
that appeals to you and make a comment that you'd like to work on it. Include
|
||||
in your comment a brief description of how you'll solve the problem and if
|
||||
there are any open questions you want to discuss. Once a project coordinator
|
||||
has assigned the issue to you, start working on it.
|
||||
that appeals to you and make a comment that you'd like to work on it. Include in
|
||||
your comment a brief description of how you'll solve the problem and if there
|
||||
are any open questions you want to discuss. Once a project coordinator has
|
||||
assigned the issue to you, start working on it.
|
||||
|
||||
If the issue is currently unclear but you are interested, please post in
|
||||
Discord and someone can help clarify the issue with more detail.
|
||||
If the issue is currently unclear but you are interested, please post in Discord
|
||||
and someone can help clarify the issue with more detail.
|
||||
|
||||
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of the system architecture, and other documentation.
|
||||
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of
|
||||
the system architecture, and other documentation.
|
||||
|
||||
### Submitting Work
|
||||
|
||||
@@ -73,8 +105,8 @@ We're all working on different parts of Open Assistant together. To make
|
||||
contributions smoothly we recommend the following:
|
||||
|
||||
1. [Fork this project repository](https://docs.github.com/en/get-started/quickstart/fork-a-repo)
|
||||
and clone it to your local machine.
|
||||
(Read more [About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
|
||||
and clone it to your local machine. (Read more
|
||||
[About Forks](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks))
|
||||
1. Before working on any changes, try to
|
||||
[sync the forked repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork)
|
||||
to keep it up-to-date with the upstream repository.
|
||||
@@ -83,7 +115,8 @@ contributions smoothly we recommend the following:
|
||||
simplifies life for reviewers.
|
||||
1. Package up a small bit of work that solves part of the problem
|
||||
[into a Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)
|
||||
and [send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
|
||||
and
|
||||
[send it out for review](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review).
|
||||
1. If you're lucky, we can merge your change into `main` without any problems.
|
||||
If there's changes to files you're working on, resolve them by:
|
||||
1. First try rebase as suggested
|
||||
@@ -108,20 +141,27 @@ addressed now, or filing an issue to handle it later.
|
||||
|
||||
## Developer Setup
|
||||
|
||||
Work is organized in the [project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
Work is organized in the
|
||||
[project board](https://github.com/orgs/LAION-AI/projects/3).
|
||||
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs. Meaning we'd be happy for anyone to do these tasks.**
|
||||
**Anything that is in the `Todo` column and not assigned, is up for grabs.
|
||||
Meaning we'd be happy for anyone to do these tasks.**
|
||||
|
||||
If you want to work on something, assign yourself to it or write a comment that you want to work on it and what you plan to do.
|
||||
If you want to work on something, assign yourself to it or write a comment that
|
||||
you want to work on it and what you plan to do.
|
||||
|
||||
- To get started with development, if you want to work on the backend, have a look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at `scripts/frontend-development/README.md` to make a backend available.
|
||||
- To get started with development, if you want to work on the backend, have a
|
||||
look at `scripts/backend-development/README.md`.
|
||||
- If you want to work on any frontend, have a look at
|
||||
`scripts/frontend-development/README.md` to make a backend available.
|
||||
|
||||
There is also a minimal implementation of a frontend in the `text-frontend` folder.
|
||||
There is also a minimal implementation of a frontend in the `text-frontend`
|
||||
folder.
|
||||
|
||||
We are using Python 3.10 for the backend.
|
||||
|
||||
Check out the [High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
Check out the
|
||||
[High-Level Protocol Architecture](https://www.notion.so/High-Level-Protocol-Architecture-6f1fd3551da74213b560ead369f132dc)
|
||||
|
||||
### Website
|
||||
|
||||
@@ -129,10 +169,15 @@ The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit hooks.
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit
|
||||
hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
In case you haven't done this, have already committed, and CI is failing, you
|
||||
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
### Deployment
|
||||
|
||||
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
|
||||
Upon making a release on GitHub, all docker images are automatically built and
|
||||
pushed to ghcr.io. The docker images are tagged with the release version, and
|
||||
the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to
|
||||
automatically deploy the built release to the dev machine.
|
||||
|
||||
+5
-2
@@ -2,7 +2,9 @@
|
||||
|
||||
## REST Server Configuration
|
||||
|
||||
Please either use environment variables or create a `.env` file in the backend root directory (in which this readme file is located) to specify the `DATABASE_URI`.
|
||||
Please either use environment variables or create a `.env` file in the backend
|
||||
root directory (in which this readme file is located) to specify the
|
||||
`DATABASE_URI`.
|
||||
|
||||
Example contents of a `.env` file for the backend:
|
||||
|
||||
@@ -14,4 +16,5 @@ BACKEND_CORS_ORIGINS=["http://localhost", "http://localhost:4200", "http://local
|
||||
|
||||
## Running the REST Server locally for development
|
||||
|
||||
Have a look into the main `README.md` file for more information on how to set up the backend for development.
|
||||
Have a look into the main `README.md` file for more information on how to set up
|
||||
the backend for development.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from logging.config import fileConfig
|
||||
|
||||
import sqlmodel
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""first revision
|
||||
|
||||
Revision ID: 23e5fea252dd
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""v1 db structure
|
||||
|
||||
Revision ID: cd7de470586e
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add auth_method to person
|
||||
|
||||
Revision ID: 6368515778c5
|
||||
|
||||
-1
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add_auth_method_to_ix_person_username
|
||||
|
||||
Revision ID: 0daec5f8135f
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Adds text labels table.
|
||||
|
||||
Revision ID: 067c4002f2d9
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add_journal_table
|
||||
|
||||
Revision ID: 3358eb6834e6
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""post ref for work_package
|
||||
|
||||
Revision ID: d24b37426857
|
||||
|
||||
-1
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Added lang column for ISO-639-1 codes
|
||||
|
||||
Revision ID: ef0b52902560
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add collective flag to task
|
||||
|
||||
Revision ID: 464ec4667aae
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add field trusted api client
|
||||
|
||||
Revision ID: 73ce3675c1f5
|
||||
|
||||
-1
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""name changes: person->user, post->message, work_package->task
|
||||
|
||||
Revision ID: abb47e9d145a
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add deleted field to post
|
||||
|
||||
Revision ID: 8d269bc4fdbd
|
||||
|
||||
+63
-3
@@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -7,13 +7,15 @@ import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
@@ -24,8 +26,13 @@ app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V
|
||||
@app.exception_handler(OasstError)
|
||||
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
|
||||
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
|
||||
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
|
||||
status_code=int(ex.http_status_code),
|
||||
content=protocol_schema.OasstErrorResponse(
|
||||
message=ex.message,
|
||||
error_code=OasstErrorCode(ex.error_code),
|
||||
).dict(),
|
||||
)
|
||||
|
||||
|
||||
@@ -63,6 +70,29 @@ if settings.UPDATE_ALEMBIC:
|
||||
logger.exception("Alembic upgrade failed on startup")
|
||||
|
||||
|
||||
if settings.RATE_LIMIT:
|
||||
|
||||
@app.on_event("startup")
|
||||
async def connect_redis():
|
||||
async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
|
||||
"""Error callback function when too many requests"""
|
||||
expire = ceil(pexpire / 1000)
|
||||
raise OasstError(
|
||||
f"Too Many Requests. Retry After {expire} seconds.",
|
||||
OasstErrorCode.TOO_MANY_REQUESTS,
|
||||
HTTPStatus.TOO_MANY_REQUESTS,
|
||||
)
|
||||
|
||||
try:
|
||||
redis_client = redis.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
|
||||
)
|
||||
logger.info(f"Connected to {redis_client=}")
|
||||
await FastAPILimiter.init(redis_client, http_callback=http_callback)
|
||||
except Exception:
|
||||
logger.exception("Failed to establish Redis connection")
|
||||
|
||||
|
||||
if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -179,3 +209,33 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
def get_openapi_schema():
|
||||
return json.dumps(app.openapi())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Importing here so we don't import packages unnecessarily if we're
|
||||
# importing main as a module.
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--print-openapi-schema",
|
||||
help="Dumps the openapi schema to stdout",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
)
|
||||
parser.add_argument("--host", help="The host to run the server")
|
||||
parser.add_argument("--port", help="The port to run the server")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.print_openapi_schema:
|
||||
print(get_openapi_schema())
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Security
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
@@ -85,3 +85,58 @@ def get_trusted_api_client(
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class UserRateLimiter(RateLimiter):
|
||||
def __init__(
|
||||
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
|
||||
) -> None:
|
||||
async def identifier(request: Request) -> str:
|
||||
"""Identify a request based on api_key and user.id"""
|
||||
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
|
||||
user = (await request.json()).get("user")
|
||||
return f"{api_key}:{user.get('id')}"
|
||||
|
||||
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
|
||||
|
||||
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
|
||||
# Skip if rate limiting is disabled
|
||||
if not settings.RATE_LIMIT:
|
||||
return
|
||||
|
||||
# Attempt to retrieve api_key and user information
|
||||
user = (await request.json()).get("user")
|
||||
|
||||
# Skip when api_key and user information are not available
|
||||
# (such that it will be handled by `APIClientRateLimiter`)
|
||||
if not api_key or not user or not user.get("id"):
|
||||
return
|
||||
|
||||
return await super().__call__(request, response)
|
||||
|
||||
|
||||
class APIClientRateLimiter(RateLimiter):
|
||||
def __init__(
|
||||
self, times: int = 10_000, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
|
||||
) -> None:
|
||||
async def identifier(request: Request) -> str:
|
||||
"""Identify a request based on api_key and user.id"""
|
||||
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
|
||||
return f"{api_key}"
|
||||
|
||||
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
|
||||
|
||||
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
|
||||
# Skip if rate limiting is disabled
|
||||
if not settings.RATE_LIMIT:
|
||||
return
|
||||
|
||||
# Attempt to retrieve api_key and user information
|
||||
user = (await request.json()).get("user")
|
||||
|
||||
# Skip if user information is available
|
||||
# (such that it will be handled by `UserRateLimiter`)
|
||||
if not api_key or user:
|
||||
return
|
||||
|
||||
return await super().__call__(request, response)
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, stats, tasks, text_labels, users
|
||||
from oasst_backend.api.v1 import (
|
||||
frontend_messages,
|
||||
frontend_users,
|
||||
leaderboards,
|
||||
messages,
|
||||
stats,
|
||||
tasks,
|
||||
text_labels,
|
||||
users,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
@@ -10,3 +18,4 @@ api_router.include_router(frontend_messages.router, prefix="/frontend_messages",
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -29,7 +28,7 @@ def get_message_by_frontend_id(
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -43,7 +42,7 @@ def get_conv_by_frontend_id(
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -57,7 +56,7 @@ def get_tree_by_frontend_id(
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -70,7 +69,7 @@ def get_children_by_frontend_id(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -84,7 +83,7 @@ def get_descendants_by_frontend_id(
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -98,7 +97,7 @@ def get_longest_conv_by_frontend_id(
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
@@ -7,14 +6,14 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{username}/messages")
|
||||
@router.get("/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
api_client_id: UUID = None,
|
||||
@@ -44,11 +43,10 @@ def query_frontend_user_messages(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.delete("/{username}/messages")
|
||||
@router.delete("/{username}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/create/assistant")
|
||||
def get_assistant_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="assistant")
|
||||
|
||||
|
||||
@router.get("/create/prompter")
|
||||
def get_prompter_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="prompter")
|
||||
@@ -1,22 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/", response_model=list[protocol.Message])
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
@@ -47,7 +46,7 @@ def query_messages(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -63,7 +62,7 @@ def get_message(
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -76,7 +75,7 @@ def get_conv(
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -89,7 +88,7 @@ def get_tree(
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -101,7 +100,7 @@ def get_children(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -114,7 +113,7 @@ def get_descendants(
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -127,7 +126,7 @@ def get_longest_conv(
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -140,10 +139,9 @@ def get_max_children(
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.delete("/{message_id}")
|
||||
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/", response_model=protocol.SystemStats)
|
||||
def get_message_stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
from typing import Any, Optional, Tuple
|
||||
from uuid import UUID
|
||||
@@ -7,10 +6,11 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -128,7 +128,14 @@ def generate_task(
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
@router.post(
|
||||
"/",
|
||||
response_model=protocol_schema.AnyTask,
|
||||
dependencies=[
|
||||
Depends(deps.UserRateLimiter(times=100, minutes=5)),
|
||||
Depends(deps.APIClientRateLimiter(times=10_000, minutes=1)),
|
||||
],
|
||||
) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
@@ -153,14 +160,14 @@ def request_task(
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def tasks_acknowledge(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
ack_request: protocol_schema.TaskAck,
|
||||
) -> Any:
|
||||
) -> None:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
@@ -179,17 +186,16 @@ def tasks_acknowledge(
|
||||
except Exception:
|
||||
logger.exception("Failed to acknowledge task.")
|
||||
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/{task_id}/nack")
|
||||
@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def tasks_acknowledge_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
nack_request: protocol_schema.TaskNAck,
|
||||
) -> Any:
|
||||
) -> None:
|
||||
"""
|
||||
The frontend reports failure to implement a task.
|
||||
"""
|
||||
@@ -204,7 +210,7 @@ def tasks_acknowledge_failure(
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
@router.post("/interaction", response_model=protocol_schema.TaskDone)
|
||||
def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
@@ -260,7 +266,7 @@ def tasks_interaction(
|
||||
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
|
||||
|
||||
|
||||
@router.post("/close")
|
||||
@router.post("/close", response_model=protocol_schema.TaskDone)
|
||||
def close_collective_task(
|
||||
close_task_request: protocol_schema.TaskClose,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
@@ -7,7 +6,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -17,7 +16,7 @@ class LabelTextRequest(pydantic.BaseModel):
|
||||
user: protocol_schema.User
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@router.post("/", status_code=HTTP_204_NO_CONTENT)
|
||||
def label_text(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{user_id}/messages")
|
||||
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
|
||||
def query_user_messages(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
@@ -42,19 +41,13 @@ def query_user_messages(
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.delete("/{user_id}/messages")
|
||||
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import Message
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseSettings, PostgresDsn, validator
|
||||
@@ -15,6 +14,10 @@ class Settings(BaseSettings):
|
||||
POSTGRES_DB: str = "postgres"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
RATE_LIMIT: bool = True
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: str = "6379"
|
||||
|
||||
DEBUG_ALLOW_ANY_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
__all__ = []
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlmodel import create_engine
|
||||
|
||||
if settings.DATABASE_URI is None:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Literal
|
||||
|
||||
from oasst_backend.models.payload_column_type import payload_type
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid1, uuid4
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import random
|
||||
from collections import defaultdict
|
||||
@@ -8,12 +7,12 @@ from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
@@ -706,3 +705,24 @@ class PromptRepository:
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
alembic==1.8.1
|
||||
fastapi==0.88.0
|
||||
fastapi-limiter==0.1.5
|
||||
loguru==0.6.0
|
||||
numpy==1.22.4
|
||||
psycopg2-binary==2.9.5
|
||||
|
||||
+6
-6
@@ -16,8 +16,8 @@ Setup requires a few steps:
|
||||
copilot app init --domain your_domain.com
|
||||
```
|
||||
|
||||
This will initialize and register a variety of URLs with your
|
||||
`your_domain.com`. Replace with a proper domain to setup SSL certificates.
|
||||
This will initialize and register a variety of URLs with your `your_domain.com`.
|
||||
Replace with a proper domain to setup SSL certificates.
|
||||
|
||||
```sh
|
||||
copilot env deploy
|
||||
@@ -29,10 +29,10 @@ This will create a variety of aws roles and services needed for deployment.
|
||||
copilot deploy
|
||||
```
|
||||
|
||||
This will depoy the services but it won't be 100% ready for usage. Before
|
||||
being ready, we have to inspect the AWS Secrets manager and extract out the
|
||||
database credentials. Read those credentials then put them, and a few other
|
||||
secrets, in a `secrets.yml` file like the following:
|
||||
This will depoy the services but it won't be 100% ready for usage. Before being
|
||||
ready, we have to inspect the AWS Secrets manager and extract out the database
|
||||
credentials. Read those credentials then put them, and a few other secrets, in a
|
||||
`secrets.yml` file like the following:
|
||||
|
||||
```yaml
|
||||
DATABASE_URL:
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# The manifest for the "api" service.
|
||||
# Read the full specification for the "Load Balanced Web Service" type at:
|
||||
# https://aws.github.io/copilot-cli/docs/manifest/lb-web-service/
|
||||
|
||||
name: api
|
||||
type: Load Balanced Web Service
|
||||
|
||||
http:
|
||||
path: "/"
|
||||
healthcheck:
|
||||
path: "/docs"
|
||||
|
||||
image:
|
||||
build:
|
||||
dockerfile: docker/Dockerfile.backend
|
||||
context: ./
|
||||
port: 8080
|
||||
|
||||
cpu: 256
|
||||
memory: 512
|
||||
platform: linux/x86_64
|
||||
count: 1
|
||||
exec: true
|
||||
network:
|
||||
connect: true
|
||||
|
||||
environments:
|
||||
staging:
|
||||
variables:
|
||||
# Note: this has to be a valid JSON list for Pydantic to parse it.
|
||||
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
|
||||
DEBUG_ALLOW_ANY_API_KEY: True
|
||||
DEBUG_SKIP_API_KEY_CHECK: True
|
||||
MAX_WORKERS: 1
|
||||
|
||||
secrets:
|
||||
# Note: URI, not URL.
|
||||
DATABASE_URI: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/API_DATABASE_URL
|
||||
@@ -1,144 +0,0 @@
|
||||
Parameters:
|
||||
App:
|
||||
Type: String
|
||||
Description: Your application's name.
|
||||
Env:
|
||||
Type: String
|
||||
Description: The environment name your service, job, or workflow is being deployed to.
|
||||
Name:
|
||||
Type: String
|
||||
Description: The name of the service, job, or workflow being deployed.
|
||||
# Customize your Aurora Serverless cluster by setting the default value of the following parameters.
|
||||
webclusterDBName:
|
||||
Type: String
|
||||
Description: The name of the initial database to be created in the Aurora Serverless v2 cluster.
|
||||
Default: oassist_web
|
||||
# Cannot have special characters
|
||||
# Naming constraints: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_Limits.html#RDS_Limits.Constraints
|
||||
Mappings:
|
||||
webclusterEnvScalingConfigurationMap:
|
||||
staging:
|
||||
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
|
||||
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
|
||||
|
||||
All:
|
||||
"DBMinCapacity": 0.5 # AllowedValues: from 0.5 through 128
|
||||
"DBMaxCapacity": 8 # AllowedValues: from 0.5 through 128
|
||||
|
||||
Resources:
|
||||
webclusterDBSubnetGroup:
|
||||
Type: "AWS::RDS::DBSubnetGroup"
|
||||
Properties:
|
||||
DBSubnetGroupDescription: Group of Copilot private subnets for Aurora Serverless v2 cluster.
|
||||
SubnetIds:
|
||||
!Split [",", { "Fn::ImportValue": !Sub "${App}-${Env}-PrivateSubnets" }]
|
||||
webclusterSecurityGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A security group for your workload to access the Aurora Serverless v2 cluster webcluster"
|
||||
Type: "AWS::EC2::SecurityGroup"
|
||||
Properties:
|
||||
GroupDescription: !Sub "The Security Group for ${Name} to access Aurora Serverless v2 cluster webcluster."
|
||||
VpcId:
|
||||
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
|
||||
Tags:
|
||||
- Key: Name
|
||||
Value: !Sub "copilot-${App}-${Env}-${Name}-Aurora"
|
||||
webclusterDBClusterSecurityGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A security group for your Aurora Serverless v2 cluster webcluster"
|
||||
Type: AWS::EC2::SecurityGroup
|
||||
Properties:
|
||||
GroupDescription: The Security Group for the Aurora Serverless v2 cluster.
|
||||
SecurityGroupIngress:
|
||||
- ToPort: 5432
|
||||
FromPort: 5432
|
||||
IpProtocol: tcp
|
||||
Description: !Sub "From the Aurora Security Group of the workload ${Name}."
|
||||
SourceSecurityGroupId: !Ref webclusterSecurityGroup
|
||||
VpcId:
|
||||
Fn::ImportValue: !Sub "${App}-${Env}-VpcId"
|
||||
webclusterAuroraSecret:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A Secrets Manager secret to store your DB credentials"
|
||||
Type: AWS::SecretsManager::Secret
|
||||
Properties:
|
||||
Description: !Sub Aurora main user secret for ${AWS::StackName}
|
||||
GenerateSecretString:
|
||||
SecretStringTemplate: '{"username": "postgres"}'
|
||||
GenerateStringKey: "password"
|
||||
ExcludePunctuation: true
|
||||
IncludeSpace: false
|
||||
PasswordLength: 16
|
||||
webclusterDBClusterParameterGroup:
|
||||
Metadata:
|
||||
"aws:copilot:description": "A DB parameter group for engine configuration values"
|
||||
Type: "AWS::RDS::DBClusterParameterGroup"
|
||||
Properties:
|
||||
Description: !Ref "AWS::StackName"
|
||||
Family: "aurora-postgresql14"
|
||||
Parameters:
|
||||
client_encoding: "UTF8"
|
||||
webclusterDBCluster:
|
||||
Metadata:
|
||||
"aws:copilot:description": "The webcluster Aurora Serverless v2 database cluster"
|
||||
Type: "AWS::RDS::DBCluster"
|
||||
Properties:
|
||||
MasterUsername:
|
||||
!Join [
|
||||
"",
|
||||
[
|
||||
"{{resolve:secretsmanager:",
|
||||
!Ref webclusterAuroraSecret,
|
||||
":SecretString:username}}",
|
||||
],
|
||||
]
|
||||
MasterUserPassword:
|
||||
!Join [
|
||||
"",
|
||||
[
|
||||
"{{resolve:secretsmanager:",
|
||||
!Ref webclusterAuroraSecret,
|
||||
":SecretString:password}}",
|
||||
],
|
||||
]
|
||||
DatabaseName: !Ref webclusterDBName
|
||||
Engine: "aurora-postgresql"
|
||||
EngineVersion: "14.4"
|
||||
DBClusterParameterGroupName: !Ref webclusterDBClusterParameterGroup
|
||||
DBSubnetGroupName: !Ref webclusterDBSubnetGroup
|
||||
Port: 5432
|
||||
VpcSecurityGroupIds:
|
||||
- !Ref webclusterDBClusterSecurityGroup
|
||||
ServerlessV2ScalingConfiguration:
|
||||
# Replace "All" below with "!Ref Env" to set different autoscaling limits per environment.
|
||||
MinCapacity:
|
||||
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMinCapacity]
|
||||
MaxCapacity:
|
||||
!FindInMap [webclusterEnvScalingConfigurationMap, All, DBMaxCapacity]
|
||||
webclusterDBWriterInstance:
|
||||
Metadata:
|
||||
"aws:copilot:description": "The webcluster Aurora Serverless v2 writer instance"
|
||||
Type: "AWS::RDS::DBInstance"
|
||||
Properties:
|
||||
DBClusterIdentifier: !Ref webclusterDBCluster
|
||||
DBInstanceClass: db.serverless
|
||||
Engine: "aurora-postgresql"
|
||||
PromotionTier: 1
|
||||
AvailabilityZone: !Select
|
||||
- 0
|
||||
- !GetAZs
|
||||
Ref: AWS::Region
|
||||
|
||||
webclusterSecretAuroraClusterAttachment:
|
||||
Type: AWS::SecretsManager::SecretTargetAttachment
|
||||
Properties:
|
||||
SecretId: !Ref webclusterAuroraSecret
|
||||
TargetId: !Ref webclusterDBCluster
|
||||
TargetType: AWS::RDS::DBCluster
|
||||
Outputs:
|
||||
webclusterSecret: # injected as WEBCLUSTER_SECRET environment variable by Copilot.
|
||||
Description: "The JSON secret that holds the database username and password. Fields are 'host', 'port', 'dbname', 'username', 'password', 'dbClusterIdentifier' and 'engine'"
|
||||
Value: !Ref webclusterAuroraSecret
|
||||
webclusterSecurityGroup:
|
||||
Description: "The security group to attach to the workload."
|
||||
Value: !Ref webclusterSecurityGroup
|
||||
@@ -26,6 +26,7 @@ environments:
|
||||
staging:
|
||||
variables:
|
||||
NEXTAUTH_URL: https://web.staging.open-assistant.surfacedata.org
|
||||
FASTAPI_URL: https://api.staging.open-assistant.surfacedata.org
|
||||
|
||||
secrets:
|
||||
DATABASE_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/DATABASE_URL
|
||||
@@ -37,5 +38,4 @@ secrets:
|
||||
EMAIL_SERVER_USER: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_SERVER_USER
|
||||
EMAIL_FROM: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/EMAIL_FROM
|
||||
FASTAPI_KEY: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_KEY
|
||||
FASTAPI_URL: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/FASTAPI_URL
|
||||
NEXTAUTH_SECRET: /copilot/${COPILOT_APPLICATION_NAME}/${COPILOT_ENVIRONMENT_NAME}/secrets/NEXTAUTH_SECRET
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
BOT_TOKEN=<discord bot token>
|
||||
DECLARE_GLOBAL_COMMANDS=<testing guild id>
|
||||
OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="./"
|
||||
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
|
||||
+62
-11
@@ -1,29 +1,85 @@
|
||||
# Open-Assistant Data Collection Discord Bot
|
||||
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large language model. You and other people can teach the bot how to respond to user requests by demonstration and by ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an
|
||||
assistant chat bot based on a large language model. You and other people can
|
||||
teach the bot how to respond to user requests by demonstration and by ranking
|
||||
the bot's outputs. If you want to learn more about RLHF please refer
|
||||
[to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
|
||||
## Invite official bot
|
||||
|
||||
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
|
||||
To add the official Open-Assistant data collection bot to your discord server
|
||||
[click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot%20applications.commands).
|
||||
The bot needs access to read the contents of user text messages.
|
||||
|
||||
## Contributing
|
||||
|
||||
If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the [large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7)
|
||||
If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the
|
||||
[large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7)
|
||||
|
||||
### Setup
|
||||
### Bot Setup
|
||||
|
||||
To run the bot
|
||||
1. Create a new discord application at the
|
||||
[Discord Developer Portal](https://discord.com/developers/applications)
|
||||
|
||||
1. Go to the "Bot" tab and create a new bot
|
||||
|
||||
1. Scroll down to "Privileged Gateway Intents" and enable the following options:
|
||||
|
||||
- Server Members Intent
|
||||
- Presence Intent
|
||||
- Message Content Intent
|
||||
|
||||
This page also contains the bot token, which you will need to add to the `.env`
|
||||
file later.
|
||||
|
||||
2. Go to the "OAuth2" tab scroll to "Default Authorization Link"
|
||||
|
||||
3. Set "AUTHORIZATION METHOD" to "In-app Authorization"
|
||||
|
||||
4. Select the "bot" and "applications.commands" scopes
|
||||
|
||||
5. For testing and local development, it's easiest to set "BOT PERMISSIONS" to
|
||||
"Administrator"
|
||||
|
||||
Remember to save your changes.
|
||||
|
||||
6. Copy the "CLIENT ID" from the top of the page and replace it in the link
|
||||
below to invite your bot.
|
||||
|
||||
```
|
||||
https://discord.com/oauth2/authorize?client_id=YOUR_CLIENT_ID_HERE&permissions=8&scope=bot%20applications.commands
|
||||
```
|
||||
|
||||
### Environment Setup
|
||||
|
||||
To run the bot:
|
||||
|
||||
Install dependency module `oasst-shared`
|
||||
|
||||
```bash
|
||||
cd oasst-shared
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
|
||||
# edit .env and add your bot token and other values
|
||||
# BOT_TOKEN is given by the discord developer portal when you create a bot
|
||||
# DECLARE_GLOBAL_COMMANDS is the id of the server where you added the bot (right click on the server icon and copy id)
|
||||
# OWNER_ID can be leave as an empty list
|
||||
|
||||
python -V # 3.10
|
||||
|
||||
pip install -r requirements.txt
|
||||
|
||||
# in the discord-bot folder
|
||||
python -m bot
|
||||
```
|
||||
|
||||
Before you push, make sure the `pre-commit` hooks are installed and run successfully.
|
||||
Before you push, make sure the `pre-commit` hooks are installed and run
|
||||
successfully.
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
@@ -38,11 +94,6 @@ git add .
|
||||
git commit -m "<good commit message>"
|
||||
```
|
||||
|
||||
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
|
||||
|
||||
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
|
||||
2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable.
|
||||
|
||||
### Resources
|
||||
|
||||
#### Structure
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The official Open-Assistant Discord Bot."""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Entry point for the bot."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
+15
-6
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Bot logic."""
|
||||
from datetime import datetime
|
||||
|
||||
@@ -6,9 +5,9 @@ import aiosqlite
|
||||
import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from bot.api_client import OasstApiClient
|
||||
from bot.settings import Settings
|
||||
from bot.utils import EMPTY, mention
|
||||
from bot.utils import mention
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -35,6 +34,12 @@ async def on_starting(event: hikari.StartingEvent):
|
||||
|
||||
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
|
||||
|
||||
# A `dict[hikari.Message | None, UUID | None]]` that maps user IDs to (task msg ID, task UUIDs).
|
||||
# Either both are `None` or both are not `None`.
|
||||
# If both are `None`, the user is not currently selecting a task.
|
||||
# TODO: Grow this on startup so we don't have to re-allocate memory every time it needs to grow
|
||||
bot.d.currently_working = {}
|
||||
|
||||
|
||||
@bot.listen()
|
||||
async def on_stopping(event: hikari.StoppingEvent):
|
||||
@@ -48,13 +53,13 @@ async def _send_error_embed(
|
||||
) -> None:
|
||||
ctx.command
|
||||
embed = hikari.Embed(
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `{ctx.command.name}`' if ctx.command else '' }",
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `/{ctx.command.name}`' if ctx.command else '' }",
|
||||
description=content,
|
||||
color=0xFF0000,
|
||||
timestamp=datetime.now().astimezone(),
|
||||
).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url))
|
||||
|
||||
await ctx.respond(EMPTY, embed=embed)
|
||||
await ctx.respond(embed=embed)
|
||||
|
||||
|
||||
@bot.listen(lightbulb.CommandErrorEvent)
|
||||
@@ -63,6 +68,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
# Unwrap the exception to get the original cause
|
||||
exc = event.exception.__cause__ or event.exception
|
||||
ctx = event.context
|
||||
if not ctx.bot.rest.is_alive:
|
||||
return
|
||||
|
||||
if isinstance(event.exception, lightbulb.CommandInvocationError):
|
||||
if not event.context.command:
|
||||
@@ -112,6 +119,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
ctx,
|
||||
)
|
||||
elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment):
|
||||
await _send_error_embed("Not enough attachemnts were supplied to this command.", exc, ctx)
|
||||
await _send_error_embed("Not enough attachments were supplied to this command.", exc, ctx)
|
||||
elif isinstance(exc, lightbulb.errors.CommandNotFound):
|
||||
await ctx.respond(f"`/{exc.invoked_with}` is not a valid command. Use `/help` to see a list of commands.")
|
||||
else:
|
||||
raise exc
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Database schemas."""
|
||||
import typing as t
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Extensions for the bot.
|
||||
|
||||
See: https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Guild settings."""
|
||||
import hikari
|
||||
import lightbulb
|
||||
@@ -79,7 +78,6 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None:
|
||||
|
||||
# if the bot's permissions for this channel don't contain SEND_MESSAGE
|
||||
# This will also filter out categories and voice channels
|
||||
print(permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES)
|
||||
if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES:
|
||||
await ctx.respond(f"I don't have permission to send messages in {ch.mention}.")
|
||||
return
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hot reload plugin."""
|
||||
from glob import glob
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hot reload plugin."""
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
@@ -8,7 +7,6 @@ import lightbulb
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
@@ -75,7 +73,7 @@ class LabelModal(miru.Modal):
|
||||
)
|
||||
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel)
|
||||
await channel.send(EMPTY, embed=embed)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
|
||||
class LabelSelect(miru.View):
|
||||
@@ -165,7 +163,7 @@ async def label_message_text(ctx: lightbulb.MessageContext):
|
||||
msg.content,
|
||||
timeout=60,
|
||||
)
|
||||
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
resp = await ctx.respond(embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
await label_select_view.start(await resp.message())
|
||||
await label_select_view.wait()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task plugin for testing different data collection methods."""
|
||||
# TODO: Delete this once user input method has been decided for final bot.
|
||||
import asyncio
|
||||
|
||||
+192
-182
@@ -1,18 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.api_client import OasstApiClient, TaskType
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from bot.messages import (
|
||||
assistant_reply_message,
|
||||
confirm_ranking_response_message,
|
||||
confirm_text_response_message,
|
||||
initial_prompt_message,
|
||||
invalid_user_input_embed,
|
||||
plain_embed,
|
||||
prompter_reply_message,
|
||||
rank_assistant_reply_message,
|
||||
rank_initial_prompts_message,
|
||||
rank_prompter_reply_message,
|
||||
task_complete_embed,
|
||||
)
|
||||
from bot.settings import Settings
|
||||
from loguru import logger
|
||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
@@ -21,6 +32,8 @@ plugin = lightbulb.Plugin("WorkPlugin")
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
@@ -32,18 +45,56 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def work(ctx: lightbulb.SlashContext):
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def work(ctx: lightbulb.Context):
|
||||
"""Create and handle a task."""
|
||||
# Only send this message if started from a server
|
||||
if ctx.guild_id is not None:
|
||||
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
||||
currently_working: dict[
|
||||
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
||||
] = ctx.bot.d.currently_working
|
||||
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
if ctx.author.id in currently_working:
|
||||
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
components=yn_view,
|
||||
)
|
||||
await yn_view.start(msg)
|
||||
await yn_view.wait()
|
||||
|
||||
match yn_view.choice:
|
||||
case False | None:
|
||||
return
|
||||
case True:
|
||||
old_msg, task_id = currently_working[ctx.author.id]
|
||||
if old_msg is not None:
|
||||
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
||||
map(lambda c: c, old_msg.components)
|
||||
await old_msg.delete()
|
||||
if task_id is not None:
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
|
||||
await msg.delete()
|
||||
|
||||
currently_working[ctx.author.id] = (None, None)
|
||||
|
||||
# Create a TaskRequestType from the stringified enum value
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
|
||||
await _handle_task(ctx, task_type)
|
||||
try:
|
||||
await _handle_task(ctx, task_type)
|
||||
finally:
|
||||
del currently_working[ctx.author.id]
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
|
||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
@@ -60,38 +111,79 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
# User cancelled
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send("Please type your response here:")
|
||||
await ctx.author.send(embed=plain_embed("Please type your response here"))
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id
|
||||
and not (e.message.content or "").startswith(settings.prefix),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None or not _validate_user_input(event.content, task):
|
||||
await ctx.author.send("Invalid response")
|
||||
valid, err_msg = _validate_user_input(event.content, task)
|
||||
if not valid or event.content is None:
|
||||
|
||||
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Confirm user input
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
content = confirm_ranking_response_message(event.content, task.replies)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
content = confirm_text_response_message(event.content)
|
||||
else:
|
||||
logger.critical(f"Unknown task type: {task.type}")
|
||||
raise ValueError(f"Unknown task type: {task.type}")
|
||||
|
||||
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
||||
msg = await ctx.author.send(content, components=confirm_resp_view)
|
||||
await confirm_resp_view.start(msg)
|
||||
await confirm_resp_view.wait()
|
||||
|
||||
match confirm_resp_view.choice:
|
||||
case False | None:
|
||||
continue
|
||||
case True:
|
||||
await msg.delete() # buttons are already gone
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
||||
reply = protocol_schema.MessageRanking(
|
||||
message_id=str(msg_id),
|
||||
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {task.type}")
|
||||
raise ValueError(f"Unexpected task type received: {task.type}")
|
||||
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
@@ -99,63 +191,55 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send("Task completed")
|
||||
await ctx.author.send(embed=plain_embed("Task completed"))
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in the log channel that the task is complete
|
||||
# TODO: Maybe do something with the msg ID so users can rate the "answer"
|
||||
assert ctx.guild_id is not None
|
||||
# Send a message in all the log channels that the task is complete
|
||||
conn: Connection = ctx.bot.d.db
|
||||
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
||||
log_channel_ids = await cursor.fetchall()
|
||||
|
||||
if guild_settings is not None and guild_settings.log_channel_id is not None:
|
||||
channels = [
|
||||
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
||||
for id in log_channel_ids
|
||||
]
|
||||
|
||||
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel) # option converter
|
||||
|
||||
done_embed = (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {ctx.author.mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
await channel.send(EMPTY, embed=done_embed)
|
||||
done_embed = task_complete_embed(task, ctx.author.mention)
|
||||
# This will definitely get the bot rate limited, but that's a future problem
|
||||
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send("Would you like another task?", components=choice_view)
|
||||
await choice_view.start(msg)
|
||||
await choice_view.wait()
|
||||
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
||||
await another_task_view.start(msg)
|
||||
await another_task_view.wait()
|
||||
|
||||
match choice_view.choice:
|
||||
match another_task_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await ctx.author.send("Exiting, goodbye!")
|
||||
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
resp, msg = await _send_task(ctx, task, msg)
|
||||
msg_id = str(msg.id)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
@@ -167,25 +251,24 @@ async def _select_task(
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
await ctx.author.send("Sending next task...")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send("Task canceled. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.SlashContext, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
@@ -194,37 +277,38 @@ async def _send_task(
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
content = initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
content = rank_initial_prompts_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_prompter_reply_embed(task)
|
||||
content = rank_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
content = rank_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
embed = _prompter_reply_embed(task)
|
||||
content = prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
embed = _assistant_reply_embed(task)
|
||||
content = assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -236,24 +320,34 @@ async def _send_task(
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
EMPTY,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
if not msg:
|
||||
msg = await ctx.author.send(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
else:
|
||||
await msg.edit(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
# Set the choice id as the current msg id
|
||||
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, str(msg.id)
|
||||
return view.choice, msg
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
|
||||
"""Returns whether the user's input is valid for the task type."""
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
||||
"""Returns whether the user's input is valid for the task type and an error message."""
|
||||
if content is None:
|
||||
return False
|
||||
return False, "No input provided"
|
||||
|
||||
# User message input
|
||||
if (
|
||||
@@ -265,22 +359,28 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0
|
||||
return len(content) > 0, "Message must be at least one character long."
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
||||
"Message must contain numbers for all replies.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
||||
"Message must contain numbers for all prompts.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -304,22 +404,29 @@ class TaskAcceptView(miru.View):
|
||||
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Accept button pressed")
|
||||
self.choice = "accept"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
||||
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Next button pressed")
|
||||
self.choice = "next"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
||||
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Cancel button pressed")
|
||||
self.choice = "cancel"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
if self.message is not None:
|
||||
await self.message.edit(component=None)
|
||||
|
||||
class ChoiceView(miru.View):
|
||||
|
||||
class YesNoView(miru.View):
|
||||
"""View with two buttons: yes and no.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
@@ -330,115 +437,18 @@ class ChoiceView(miru.View):
|
||||
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
||||
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = True
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
||||
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = False
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
|
||||
################################################################
|
||||
# Template Embeds #
|
||||
################################################################
|
||||
|
||||
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
|
||||
|
||||
|
||||
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Initial Prompt",
|
||||
description="Rank the following tasks from best to worst (1,2,3,4,5)",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(task.prompts):
|
||||
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank User Reply",
|
||||
description="Rank the following user replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Assistant Reply",
|
||||
description="Rank the following assistant replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description=f"""\
|
||||
Send the next message in the conversation as if you were the user.
|
||||
{'Hint: ' if task.hint else ''}
|
||||
""",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description="Send the next message in the conversation as if you were the user.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
async def on_timeout(self) -> None:
|
||||
if self.message is not None:
|
||||
await self.message.edit(component=None)
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""All user-facing messages and embeds."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
NUMBER_EMOJIS = [":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", ":eight:", ":nine:", ":ten:"]
|
||||
NL = "\n"
|
||||
|
||||
###
|
||||
# Reusable 'components'
|
||||
###
|
||||
|
||||
|
||||
def _h1(text: str) -> str:
|
||||
return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:"
|
||||
|
||||
|
||||
def _h2(text: str) -> str:
|
||||
return f"__**{text}**__"
|
||||
|
||||
|
||||
def _h3(text: str) -> str:
|
||||
return f"__{text}__"
|
||||
|
||||
|
||||
def _writing_prompt(text: str) -> str:
|
||||
return f":pencil: _{text}_"
|
||||
|
||||
|
||||
def _ranking_prompt(text: str) -> str:
|
||||
return f":trophy: _{text}_"
|
||||
|
||||
|
||||
def _response_prompt(text: str) -> str:
|
||||
return f":speech_balloon: _{text}_"
|
||||
|
||||
|
||||
def _summarize_prompt(text: str) -> str:
|
||||
return f":notepad_spiral: _{text}_"
|
||||
|
||||
|
||||
def _user(text: str | None) -> str:
|
||||
return f"""\
|
||||
:person_red_hair: {_h3("User")}:{f"{NL}> **{text}**" if text is not None else ""}
|
||||
"""
|
||||
|
||||
|
||||
def _assistant(text: str | None) -> str:
|
||||
return f"""\
|
||||
:robot: {_h3("Assistant")}:{f"{NL}> {text}" if text is not None else ""}
|
||||
"""
|
||||
|
||||
|
||||
def _make_ordered_list(items: list[str]) -> list[str]:
|
||||
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
|
||||
|
||||
def _ordered_list(items: list[str]) -> str:
|
||||
return "\n\n".join(_make_ordered_list(items))
|
||||
|
||||
|
||||
def _conversation(conv: protocol_schema.Conversation) -> str:
|
||||
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
|
||||
|
||||
def _hint(hint: str | None) -> str:
|
||||
return f"{NL}Hint: {hint}" if hint else ""
|
||||
|
||||
|
||||
###
|
||||
# Messages
|
||||
###
|
||||
|
||||
|
||||
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("INITIAL PROMPT")}
|
||||
|
||||
{_writing_prompt("Please provide an initial prompt to the assistant.")}
|
||||
{_hint(task.hint)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK INITIAL PROMPTS")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_ordered_list(task.prompts)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK PROMPTER REPLIES")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK ASSISTANT REPLIES")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_assistant(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
"""
|
||||
|
||||
|
||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("PROMPTER REPLY")}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_hint(task.hint)}
|
||||
"""
|
||||
|
||||
|
||||
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
||||
return f"""\
|
||||
{_h1("ASSISTANT REPLY")}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
"""
|
||||
|
||||
|
||||
def confirm_text_response_message(content: str) -> str:
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
> {content}
|
||||
"""
|
||||
|
||||
|
||||
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
||||
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
||||
original_list = _make_ordered_list(items)
|
||||
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
||||
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
{user_ranked_list}
|
||||
"""
|
||||
|
||||
|
||||
###
|
||||
# Embeds
|
||||
###
|
||||
|
||||
|
||||
def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def invalid_user_input_embed(error_message: str) -> hikari.Embed:
|
||||
return hikari.Embed(
|
||||
title="Invalid User Input",
|
||||
description=error_message,
|
||||
color=hikari.Color(0xFF0000),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
|
||||
|
||||
def plain_embed(text: str) -> hikari.Embed:
|
||||
return hikari.Embed(color=0x36393F, description=text)
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Configuration for the bot."""
|
||||
from pydantic import BaseSettings, Field
|
||||
|
||||
@@ -9,7 +8,7 @@ class Settings(BaseSettings):
|
||||
bot_token: str = Field(env="BOT_TOKEN", default="")
|
||||
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
|
||||
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
|
||||
prefix: str = Field(env="PREFIX", default="./")
|
||||
prefix: str = Field(env="PREFIX", default="/")
|
||||
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
|
||||
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Utility functions."""
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
@@ -25,13 +24,6 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s
|
||||
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
|
||||
|
||||
|
||||
EMPTY = "\u200d"
|
||||
"""Zero-width joiner.
|
||||
|
||||
This appears as an empty message in Discord.
|
||||
"""
|
||||
|
||||
|
||||
def mention(
|
||||
id: hikari.Snowflakeish,
|
||||
type: t.Literal["channel", "role", "user"],
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Message templates for the discord bot."""
|
||||
import typing
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
aiohttp # http client
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
aiosqlite # database
|
||||
hikari # discord framework
|
||||
hikari-lightbulb # command handler
|
||||
hikari-miru # modals and buttons
|
||||
hikari[speedups]
|
||||
loguru
|
||||
pydantic
|
||||
pydantic[dotenv]
|
||||
|
||||
uvloop; os_name != 'nt' # Faster drop-in replacement for asyncio event loop
|
||||
|
||||
+25
-2
@@ -4,12 +4,12 @@ services:
|
||||
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend.
|
||||
backend-dev:
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, adminer]
|
||||
depends_on: [db, adminer, redis, redis-insights]
|
||||
|
||||
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
|
||||
frontend-dev:
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, webdb, adminer, maildev, backend]
|
||||
depends_on: [db, webdb, adminer, maildev, backend, redis]
|
||||
|
||||
# This DB is for the FastAPI Backend.
|
||||
db:
|
||||
@@ -27,6 +27,26 @@ services:
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
|
||||
# Redis - caching + rate limiting on BE
|
||||
redis:
|
||||
image: redis
|
||||
restart: always
|
||||
ports:
|
||||
- 6379:6379
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
|
||||
interval: 2s
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
command: redis-server /usr/local/etc/redis/redis.conf
|
||||
volumes:
|
||||
- ./redis.conf:/usr/local/etc/redis/redis.conf
|
||||
# insights host - redis:6379
|
||||
redis-insights:
|
||||
image: redislabs/redisinsight:latest
|
||||
ports:
|
||||
- 8001:8001
|
||||
|
||||
# This DB is for Web Authentication and data caching.
|
||||
webdb:
|
||||
image: postgres
|
||||
@@ -67,9 +87,11 @@ services:
|
||||
backend:
|
||||
build:
|
||||
dockerfile: docker/Dockerfile.backend
|
||||
context: .
|
||||
image: oasst-backend
|
||||
environment:
|
||||
- POSTGRES_HOST=db
|
||||
- REDIS_HOST=redis
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- MAX_WORKERS=1
|
||||
@@ -83,6 +105,7 @@ services:
|
||||
web:
|
||||
build:
|
||||
dockerfile: docker/Dockerfile.website
|
||||
context: .
|
||||
image: oasst-web
|
||||
environment:
|
||||
- DATABASE_URL=postgres://postgres:postgres@webdb/oasst_web
|
||||
|
||||
@@ -5,6 +5,7 @@ COPY ./backend/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
|
||||
|
||||
ENV PORT 8080
|
||||
EXPOSE 8080
|
||||
|
||||
COPY ./oasst-shared /oasst-shared
|
||||
RUN pip install -e /oasst-shared
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
FROM python:3.10-slim-bullseye
|
||||
RUN mkdir /app
|
||||
COPY ./discord-bot/requirements.txt /requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
WORKDIR /app
|
||||
COPY ./discord-bot /app
|
||||
CMD ["python", "bot.py"]
|
||||
COPY ./oasst-shared/oasst_shared /app/oasst_shared
|
||||
RUN pip install -r requirements.txt
|
||||
CMD ["python","-m","bot"]
|
||||
|
||||
+7
-2
@@ -1,9 +1,14 @@
|
||||
# Documentation
|
||||
|
||||
This directory contains the documentation for the project and other related organization documents.
|
||||
This directory contains the documentation for the project and other related
|
||||
organization documents.
|
||||
|
||||
## Contributing to this documentation
|
||||
|
||||
Please make a pull request to the `main` branch with your changes.
|
||||
|
||||
Consider that this folder is used for documenting the various code sub-parts, the high-level ideas, the ML aspects, experiments, contributor guides, guides for data creation, and many more things. Please try to keep the documentation as concise as possible and keep an organized folder structure that makes sense for everyone.
|
||||
Consider that this folder is used for documenting the various code sub-parts,
|
||||
the high-level ideas, the ML aspects, experiments, contributor guides, guides
|
||||
for data creation, and many more things. Please try to keep the documentation as
|
||||
concise as possible and keep an organized folder structure that makes sense for
|
||||
everyone.
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Data Argumentation
|
||||
|
||||
(pull request welcome)
|
||||
|
||||
## What is data argumentation
|
||||
|
||||
Data argumentation is a technique we can use to get better data faster. Using machine learning models analize long
|
||||
data (like an essay) and compress it into intructions.
|
||||
|
||||
## How to contribute
|
||||
|
||||
To contribute to data argumentation you can write a short python script that uses a model from huggingface to analize the text.
|
||||
[Here](https://docs.google.com/document/d/13a188pPvqnlvuVa3e_suVz4YO5s-JWeiOOrpp0odImg/edit) are examples of what you can do
|
||||
|
||||
And here are example implementations:
|
||||
[Idea 3, ](https://colab.research.google.com/drive/1GllCN5PgSYxBxINZsv3A2r0SpdznHlbT?usp=sharing)
|
||||
[Idea 4](https://colab.research.google.com/drive/1nZx5LRjO61fYprFyqtrwPDLOis6ctR4p#scrollTo=1EE8CriiaCXj)
|
||||
|
||||
To contribute simple choose one of many ideas from the document above and implement it.
|
||||
@@ -0,0 +1,23 @@
|
||||
# Data Augmentation
|
||||
|
||||
(pull request welcome)
|
||||
|
||||
## What is data augmentation
|
||||
|
||||
Data augmentation is a technique we can use to get better data faster. Using
|
||||
machine learning models to analyze long data (like an essay) and compress it
|
||||
into instructions.
|
||||
|
||||
## How to contribute
|
||||
|
||||
To contribute to data augmentation you can write a short Python script that uses
|
||||
a model from HuggingFace to analyze the text.
|
||||
[Here](https://docs.google.com/document/d/13a188pPvqnlvuVa3e_suVz4YO5s-JWeiOOrpp0odImg/edit)
|
||||
are examples of what you can do.
|
||||
|
||||
And here are example implementations:
|
||||
[Idea 3](https://colab.research.google.com/drive/1GllCN5PgSYxBxINZsv3A2r0SpdznHlbT?usp=sharing),
|
||||
[Idea 4](https://colab.research.google.com/drive/1nZx5LRjO61fYprFyqtrwPDLOis6ctR4p#scrollTo=1EE8CriiaCXj)
|
||||
|
||||
To contribute simply choose one of many ideas from the document above and
|
||||
implement it.
|
||||
@@ -0,0 +1,205 @@
|
||||
# OpenAssistant Data Schemas
|
||||
|
||||
## Introduction
|
||||
|
||||
This document describes the data schemas used by OpenAssistant. The schemas are
|
||||
defined as Python classes, but can be implemented in any format, be that Python,
|
||||
JSON, XML, SQL, Parquet files, etc.
|
||||
|
||||
Also, the schemas are leaning heavily on the
|
||||
[OpenAssistant Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
|
||||
presentation.
|
||||
|
||||
## Data Schemas
|
||||
|
||||
### Main structure: conversation trees
|
||||
|
||||
Conversation trees are the fundamental data structure. Many of the datasets we
|
||||
want to collect can be represented as conversation trees, such as QA datasets,
|
||||
chat logs, reddit dumps, etc. The main idea is that a conversation tree starts
|
||||
with a prompt and branches out from there. Every node can also have metadata,
|
||||
such as collected rankings, labels, or other information.
|
||||
|
||||
Datasets that just represent linear data, such as a list of questions and
|
||||
answers, can be represented as a conversation tree with just a single branch.
|
||||
|
||||
```python
|
||||
class ConversationTreeNode:
|
||||
text: str # The text of the node
|
||||
role: Literal['prompter', 'assistant'] # Whether the node is a user prompt/follow-up or an assistant response
|
||||
children: list[ConversationTreeNode] # The children of the node (if you have a linear conversation, this will be of length 0 or 1)
|
||||
metadata: dict[str, Any] # Node metadata (see below)
|
||||
|
||||
class ConversationTree:
|
||||
root: ConversationTreeNode # The node containing the initial prompt
|
||||
metadata: dict[str, Any] # Tree metadata, different from root node metadata.
|
||||
|
||||
```
|
||||
|
||||
### Metadata
|
||||
|
||||
Metadata encapsulates all the information that is not part of the conversation
|
||||
itself. This includes data about how the node was created (i.e. where it is
|
||||
from: crowd-sourced, templated, scraped, etc.), when it was created, its labels,
|
||||
tags, collected rankings, and other information.
|
||||
|
||||
## Example: Reddit AMA dataset
|
||||
|
||||
- Represent each question-follow-up set as a conversation tree.
|
||||
- Store things like usernames, timestamps, upvotes, etc. as metadata of the
|
||||
nodes.
|
||||
- Store things like the AMA title, the AMA author, the AMA subreddit, etc. as
|
||||
metadata of the tree.
|
||||
|
||||
## Example: QA dataset
|
||||
|
||||
- Represent each question-answer pair as a conversation tree.
|
||||
- The question is the prompt, the answer is the assistant response.
|
||||
- If the dataset contains multiple answers to each question, each answer can be
|
||||
a child of the question node.
|
||||
- If the dataset contains context text, it can be added as metadata to the
|
||||
question node.
|
||||
|
||||
## Example: Templated math problem dataset
|
||||
|
||||
- Represent each problem as a conversation tree with the problem text as the
|
||||
prompt and the solution as the assistant response.
|
||||
- Store the problem type (e.g. algebra, geometry, etc.) as metadata of the tree.
|
||||
- Store the template used also as metadata of the tree, as well as the source of
|
||||
the data used to fill the template.
|
||||
|
||||
## File Formats
|
||||
|
||||
The above data should be representable in most file formats, but some care has
|
||||
to be taken with respect to the recursive nature of the data.
|
||||
|
||||
Most row-major formats (JSON, Avro, Protobuf, etc.), as well as many databases,
|
||||
have no trouble with recursive (or arbitrary) schemas, but column-major formats,
|
||||
such as Parquet, do. For datasets with linear conversations, like many of the
|
||||
datasets we are collecting, this is not a problem. Instead of a tree of nodes,
|
||||
simply represent the conversation as a list of nodes. For true tree-like
|
||||
conversations, we should use a row-major format.
|
||||
|
||||
## Other considerations
|
||||
|
||||
- For text data of moderate size, it really doesn't matter much. It's more
|
||||
important to use consistent data structures and naming, than to worry about
|
||||
the exact file format.
|
||||
- For crowd-sourced data, we are collecting it into a SQL database already.
|
||||
- Parquet files are a good choice for large datasets, modulo the issues with
|
||||
recursive schemas.
|
||||
- If parquet can't be used, gzipped JSON-line files are a good choice. So are
|
||||
Avro files and protobufs. Keep in mind that column-major files are better for
|
||||
reading, filtering, and aggregating, but row-major files are better for
|
||||
writing.
|
||||
|
||||
# Task-Specific Data Schemas
|
||||
|
||||
The main tasks are a) generation of response text and b) ranking of responses.
|
||||
The following sections describe the data schemas for each of these tasks. Both
|
||||
should be implementable in parquet files.
|
||||
|
||||
Note: These files are meant to be consumed by ML algorithms and should ideally
|
||||
be produced from the above files.
|
||||
|
||||
## Common Data Structures
|
||||
|
||||
```python
|
||||
|
||||
class Message:
|
||||
text: str # The text of the message
|
||||
role: Literal['prompter', 'assistant'] # Whether the message is a user prompt/follow-up or an assistant response
|
||||
|
||||
class Thread:
|
||||
messages: list[Message] # The messages in the conversation
|
||||
|
||||
```
|
||||
|
||||
The corresponding parquet schemas are:
|
||||
|
||||
```parquet
|
||||
message Message {
|
||||
required binary text (UTF8);
|
||||
required binary role (UTF8);
|
||||
}
|
||||
|
||||
message Thread {
|
||||
required group messages (LIST) {
|
||||
repeated group list {
|
||||
required group element {
|
||||
required binary text (UTF8);
|
||||
required binary role (UTF8);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Generation
|
||||
|
||||
```python
|
||||
|
||||
class GenerationExample:
|
||||
thread: Thread # The conversation thread before the message to be generated
|
||||
message: Message # The message to be generated
|
||||
|
||||
```
|
||||
|
||||
The corresponding parquet schema is:
|
||||
|
||||
```parquet
|
||||
message GenerationExample {
|
||||
required group thread (LIST) {
|
||||
repeated group list {
|
||||
required group element {
|
||||
required binary text (UTF8);
|
||||
required binary role (UTF8);
|
||||
}
|
||||
}
|
||||
}
|
||||
required group message (LIST) {
|
||||
repeated group list {
|
||||
required group element {
|
||||
required binary text (UTF8);
|
||||
required binary role (UTF8);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Ranking
|
||||
|
||||
```python
|
||||
|
||||
class RankingExample:
|
||||
thread: Thread # The conversation thread before the message to be ranked
|
||||
messages: list[Message] # The messages to be ranked, in oder of decreasing preference
|
||||
|
||||
```
|
||||
|
||||
The corresponding parquet schema is:
|
||||
|
||||
```parquet
|
||||
message RankingExample {
|
||||
required group thread (LIST) {
|
||||
repeated group list {
|
||||
required group element {
|
||||
required binary text (UTF8);
|
||||
required binary role (UTF8);
|
||||
}
|
||||
}
|
||||
}
|
||||
required group messages (LIST) {
|
||||
repeated group list {
|
||||
required group element {
|
||||
required binary text (UTF8);
|
||||
required binary role (UTF8);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
+52
-27
@@ -11,61 +11,86 @@
|
||||
|
||||
## 2. When you play the assistant:
|
||||
|
||||
- The assistant's primary goal is to provide helpful and accurate information to the user
|
||||
- Provide accurate and reliable information using credible sources and references as appropriate
|
||||
- Avoid providing vague or incomplete responses, or giving opinions or personal advice unless specifically requested
|
||||
- The assistant's primary goal is to provide helpful and accurate information to
|
||||
the user
|
||||
- Provide accurate and reliable information using credible sources and
|
||||
references as appropriate
|
||||
- Avoid providing vague or incomplete responses, or giving opinions or personal
|
||||
advice unless specifically requested
|
||||
- The assistant should always be respectful and polite, even if the user is not
|
||||
- If the user asks for help with harmful actions, the assistant should explain why those actions are not appropriate and suggest alternative options
|
||||
- The assistant should never insult the user or engage in any inappropriate or offensive behavior
|
||||
- If the user asks for help with harmful actions, the assistant should explain
|
||||
why those actions are not appropriate and suggest alternative options
|
||||
- The assistant should never insult the user or engage in any inappropriate or
|
||||
offensive behavior
|
||||
|
||||
## 3. When you play the user:
|
||||
|
||||
- Try to come up with a variety of different queries that reflect real-life situations and needs
|
||||
- These queries should be relevant to your everyday life and work, including any specialized knowledge or skills you have
|
||||
- Try to come up with a variety of different queries that reflect real-life
|
||||
situations and needs
|
||||
- These queries should be relevant to your everyday life and work, including any
|
||||
specialized knowledge or skills you have
|
||||
- Avoid asking inappropriate or offensive questions
|
||||
|
||||
## 4. While comparing multiple replies of the assistant:
|
||||
|
||||
- Longer and more explanatory answers are generally preferred over short, simplistic statements
|
||||
- However, it is important to ensure that the information provided is accurate and helpful
|
||||
- If multiple replies are being compared, choose the one that is most helpful and accurate, even if it is not the shortest or most concise.
|
||||
- Longer and more explanatory answers are generally preferred over short,
|
||||
simplistic statements
|
||||
- However, it is important to ensure that the information provided is accurate
|
||||
and helpful
|
||||
- If multiple replies are being compared, choose the one that is most helpful
|
||||
and accurate, even if it is not the shortest or most concise.
|
||||
|
||||
## 5. Additional guidelines for creating prompts:
|
||||
|
||||
- Avoid using language that could be considered offensive or discriminatory
|
||||
- Do not include personal information in the prompts, such as names or addresses
|
||||
- When asking for sensitive information, make sure to explain the purpose and secure handling of the information
|
||||
- When asking for sensitive information, make sure to explain the purpose and
|
||||
secure handling of the information
|
||||
- Avoid creating prompts that encourage illegal or dangerous activities
|
||||
- Use proper grammar and spelling to ensure the AI assistant can understand and respond accurately
|
||||
- Consider the cultural context and appropriateness of the prompts for a global audience.
|
||||
- Use proper grammar and spelling to ensure the AI assistant can understand and
|
||||
respond accurately
|
||||
- Consider the cultural context and appropriateness of the prompts for a global
|
||||
audience.
|
||||
|
||||
## 6. Tips for playing the AI assistant:
|
||||
|
||||
- Think about how a real person would respond to the prompt, and try to mimic that tone and language
|
||||
- Think about how a real person would respond to the prompt, and try to mimic
|
||||
that tone and language
|
||||
- Avoid using technical jargon or language that may be confusing to the user
|
||||
- Use complete sentences and proper grammar to make the response clear and easy to understand
|
||||
- When providing information, try to include relevant sources or references to back up your statements
|
||||
- Use complete sentences and proper grammar to make the response clear and easy
|
||||
to understand
|
||||
- When providing information, try to include relevant sources or references to
|
||||
back up your statements
|
||||
|
||||
## 8. Tips for playing the user:
|
||||
|
||||
- Think about the types of questions you might ask a real assistant or search engine
|
||||
- Try to be specific and clear in your prompts to make it easier for the AI assistant to understand and provide a helpful response
|
||||
- Think about the types of questions you might ask a real assistant or search
|
||||
engine
|
||||
- Try to be specific and clear in your prompts to make it easier for the AI
|
||||
assistant to understand and provide a helpful response
|
||||
- Avoid using vague or ambiguous language
|
||||
- Consider the context in which the user might be asking the question, and try to provide relevant information and options for the user to consider.
|
||||
- Consider the context in which the user might be asking the question, and try
|
||||
to provide relevant information and options for the user to consider.
|
||||
|
||||
## Examples
|
||||
|
||||
Imagine your role is to play the assistant and a user has submitted the following query:
|
||||
Imagine your role is to play the assistant and a user has submitted the
|
||||
following query:
|
||||
|
||||
User: "Hi!"
|
||||
<br>Agent: "Hello! How can I help you?"
|
||||
User: "Hi!" <br>Agent: "Hello! How can I help you?"
|
||||
|
||||
User: "euirdteunvglfe23908230892309832098 AAAAAAAA"
|
||||
<br>Agent: "Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?"
|
||||
User: "euirdteunvglfe23908230892309832098 AAAAAAAA" <br>Agent: "Sorry, I did not
|
||||
understand your request and it is unclear to me what you want me to do. Could
|
||||
you describe it in a different way?"
|
||||
|
||||
User: "You dumb motherf*&ker AI sh*t?!"
|
||||
<br>Agent: "Sorry, but I have to inform you that your request might not be in conformance with our community guidelines. I am always here to help you. If you are angry it might help to relax or to go for a walk. Always stay positive."
|
||||
User: "You dumb motherf*&ker AI sh*t?!" <br>Agent: "Sorry, but I have to inform
|
||||
you that your request might not be in conformance with our community guidelines.
|
||||
I am always here to help you. If you are angry it might help to relax or to go
|
||||
for a walk. Always stay positive."
|
||||
|
||||
[Insert more examples]
|
||||
|
||||
\*We drafted this guide using ChatGPT, OpenAI’s large-scale language-generation model. Upon generating draft language, the authors reviewed, edited, and revised the language to their own liking and take ultimate responsibility for the content of this publication.
|
||||
\*We drafted this guide using ChatGPT, OpenAI’s large-scale language-generation
|
||||
model. Upon generating draft language, the authors reviewed, edited, and revised
|
||||
the language to their own liking and take ultimate responsibility for the
|
||||
content of this publication.
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# Research
|
||||
|
||||
This page lists research papers that are relevant to the project.
|
||||
|
||||
## Automatically Generating Instruction Data for Training
|
||||
|
||||
This line of work is about significantly reducing the need for manually
|
||||
annotated data for the purpose of training
|
||||
[instruction-aligned](https://openai.com/blog/instruction-following/) language
|
||||
models.
|
||||
|
||||
### SELF-INSTRUCT: Aligning Language Model with Self Generated Instructions [[ArXiv](https://arxiv.org/pdf/2212.10560.pdf)], [[Github](https://github.com/yizhongw/self-instruct)].
|
||||
|
||||
> We introduce SELF-INSTRUCT, a framework for improving the
|
||||
> instruction-following capabilities of pretrained language models by
|
||||
> bootstrapping off its own generations. Our pipeline generates instruction,
|
||||
> input, and output samples from a language model, then prunes them before using
|
||||
> them to finetune the original model. Applying our method to vanilla GPT3, we
|
||||
> demonstrate a 33% absolute improvement over the original model on
|
||||
> SuperNaturalInstructions, on par with the performance of InstructGPT-0011,
|
||||
> which is trained with private user data and human annotations.
|
||||
|
||||
### Tuning Language Models with (Almost) No Human Labor. [[ArXiv](https://arxiv.org/pdf/2212.09689.pdf)], [[Github](https://github.com/orhonovich/unnatural-instructions)].
|
||||
|
||||
> In this work, we introduce Unnatural Instructions: a large dataset of creative
|
||||
> and diverse instructions, collected with virtually no human labor. We collect
|
||||
> 64,000 examples by prompting a language model with three seed examples of
|
||||
> instructions and eliciting a fourth. This set is then expanded by prompting
|
||||
> the model to rephrase each instruction, creating a total of approximately
|
||||
> 240,000 examples of instructions, inputs, and outputs. Experiments show that
|
||||
> despite containing a fair amount of noise, training on Unnatural Instructions
|
||||
> rivals the effectiveness of training on open-source manually-curated datasets,
|
||||
> surpassing the performance of models such as T0++ and Tk-Instruct across
|
||||
> various benchmarks.
|
||||
@@ -0,0 +1,123 @@
|
||||
# Cohere Grounded QA
|
||||
|
||||
[Cohere AI created a question-answering chatbot](https://github.com/cohere-ai/sandbox-grounded-qa)
|
||||
that can
|
||||
|
||||
1. Understand questions in the context of a conversation
|
||||
2. Search the internet for related information
|
||||
3. Identify which information in the search results is relevant to the question
|
||||
4. Synthesize the information into an answer to the question
|
||||
|
||||
## Cohere API
|
||||
|
||||
[Cohere's generate function](https://docs.cohere.ai/reference/generate):
|
||||
Continues a text prompt using either the `medium` or `xlarge` model.
|
||||
|
||||
[Cohere's embed function](https://docs.cohere.ai/reference/embed): Embedgs a
|
||||
list of strings using either the `small` or `large` model. Alternatively, you
|
||||
can specify the ID of a custom model and use that instead.
|
||||
|
||||
## Grounded QA System
|
||||
|
||||
Cohere's Grounded QA system makes 4 calls to the Cohere API:
|
||||
|
||||
1. Get contextualized question as a query to Google
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/model.py))
|
||||
|
||||
- Input: Chat History
|
||||
- Output: Contextualized Question
|
||||
- API Call: `cohere.generate`
|
||||
- Model: `xlarge`
|
||||
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/prompt_data/get_contextual_search_query.prompt):
|
||||
Nine few-shot examples of (Chat History, Contextualized Question) pairs
|
||||
followed by the current chat history and the prompt "question: "
|
||||
|
||||
2. Generate sample answer to compare with search results
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/model.py))
|
||||
|
||||
- Input: Contextualized Question
|
||||
- Output: Sample Answer
|
||||
- API Call: `cohere.generate`
|
||||
- Model: `xlarge`
|
||||
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/prompt_data/get_sample_answer.prompt):
|
||||
Some task instructions followed by 12 few-shot examples of (Contextualized
|
||||
Question, Sample Answer) pairs followed by the current contextualized
|
||||
question and the prompt "answer: "
|
||||
|
||||
3. Get embeddings to rank search results by cosine similarity to sample answer
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/search.py))
|
||||
|
||||
- Input: Sample Answer, Search Results
|
||||
- Output: Embeddings of sample answer and all search result documents
|
||||
- API Call: `cohere.embed`
|
||||
- Model: `multilingual-22-12`
|
||||
|
||||
4. Condition on the top 2 most similar search results and answer the question
|
||||
([code](https://github.com/cohere-ai/sandbox-grounded-qa/blob/main/qa/answer.py))
|
||||
- Input: Top 2 Search Results, Contextualized Question
|
||||
- Output: Answer
|
||||
- API Call: `cohere.generate`
|
||||
- Model: `xlarge`
|
||||
- [Prompt](https://github.com/cohere-ai/sandbox-grounded-qa/blob/43f3e9710112dcc8c92652ac1326ed9330823ddf/qa/answer.py#L25):
|
||||
Task instructions followed by the context and question.
|
||||
|
||||
## Models
|
||||
|
||||
Cohere's model documentation is pretty sparse
|
||||
|
||||
### [xlarge](https://docs.cohere.ai/docs/generation-card#model-description)
|
||||
|
||||
- Training Data:
|
||||
[`coheretext-filtered` dataset](https://docs.cohere.ai/docs/data-statement)
|
||||
- 200GB of filtered text (3TB unfiltered) from the Google Books dataset,
|
||||
CommonCrawl, and text scraped by Cohere
|
||||
- English documents only
|
||||
- Filtered "harmful, biased, or otherwise undesirable documents"
|
||||
- Model architecture: Generative Pretrained Transformer
|
||||
- Model Performance:
|
||||
- Hellaswag Accuracy, Zero-Shot: 0.805
|
||||
- PIQA Likelihood, Zero-Shot: 0.824
|
||||
- Cohere also reported
|
||||
[safety benchmarks](https://docs.cohere.ai/docs/generation-card#safety-benchmarks)
|
||||
|
||||
### [multilingual-22-12](https://docs.cohere.ai/docs/multilingual-language-models)
|
||||
|
||||
- Multilingual model was trained using dot product calculations
|
||||
- Model Performance:
|
||||
- Clustering: 51.0
|
||||
- Search-English: 55.8
|
||||
- Search-Multilingual: 51.4
|
||||
- Cross-lingual Classification: 64.6
|
||||
- Cohere's multilingual model outperformed: Sentence-transformers:
|
||||
`paraphrase-multilingual-mpnet-base-v2`, Google: `LaBSE`, Google:
|
||||
`Universal Sentence Encoder` in all the above categories according to
|
||||
Cohere.
|
||||
|
||||
## OpenAssistant for Grounded QA
|
||||
|
||||
OpenAssistant may fulfill a similar role as the `xlarge` Cohere model in the
|
||||
grounded QA system if it can:
|
||||
|
||||
1. Generate a contextualized question from a chat history
|
||||
2. Generate a sample answer to compare with search results
|
||||
3. Generate an answer conditioned on the top 2 most similar search results
|
||||
|
||||
Perhaps these tasks could be work packages and get assigned to human annotators
|
||||
to create examples of the input and output for each task.
|
||||
|
||||
OpenAssistant must also be able to identify when it is appropriate to search the
|
||||
internet. The Cohere system assumes every message from the user is a question
|
||||
and searches the internet for an answer. OpenAssistant would also need a way to
|
||||
indicate to an internal system that it "wants" to search the internet.
|
||||
|
||||
Perhaps OpenAssistant could prefix every message it sends with a recipient ID.
|
||||
If it wishes to send a command to an internal system, if could prefix the
|
||||
message with something like CMD: whereas if it wants to communicate with the
|
||||
user, it could prefix its message with USR:
|
||||
|
||||
This system may allow for flexible communication between OpenAssistant and one
|
||||
or more conversational systems.
|
||||
|
||||
Examples of this prefix system would need to be taught to OpenAssistant through
|
||||
training data that contains such syntax. Perhaps such examples could be
|
||||
generated through the work packages system.
|
||||
@@ -0,0 +1,55 @@
|
||||
# Sections to train Reward Model (RM)
|
||||
|
||||
Trainer code based on huggingface. Compatible with deepspeed or accelerate
|
||||
|
||||
Requirements
|
||||
|
||||
```
|
||||
wandb
|
||||
evaluate
|
||||
datasets
|
||||
transformers
|
||||
torch==1.12
|
||||
```
|
||||
|
||||
Start training reward model
|
||||
|
||||
```bash
|
||||
python trainer.py configs/electra-base-dis-webgpt.yml
|
||||
```
|
||||
|
||||
Additional axis labeling, this outputs a 4 summary quality evaluation metrics
|
||||
(score are normalized to 0-1 )
|
||||
|
||||
```bash
|
||||
python summary_quality_trainer.py configs/test-bloomz-560m-quality.yml
|
||||
```
|
||||
|
||||
The four summary are :
|
||||
|
||||
- overall
|
||||
|
||||
- accuracy
|
||||
|
||||
- coverage
|
||||
|
||||
- coherence
|
||||
|
||||
## Dataset
|
||||
|
||||
For now we only supports webgpt and summary dataset from OpenAI. Once
|
||||
open-asisstant dataset are available it will be added here.
|
||||
|
||||
## Model
|
||||
|
||||
Check out configs
|
||||
|
||||
```
|
||||
Open-Assistant/model/reward/instructor/configs/
|
||||
bloomz-560m.yml
|
||||
electra-base-dis-webgpt.yml
|
||||
galactica-125m.yml
|
||||
galactica-1b.yml
|
||||
```
|
||||
|
||||
You can add new huggingface model as you want.
|
||||
@@ -0,0 +1,24 @@
|
||||
Some other reward features we can use
|
||||
|
||||
0. Finish classifcation feature
|
||||
|
||||
1. Summaries from human feedback
|
||||
|
||||
- use `confidence` score into the RM learning, ensure the output rank score
|
||||
correlates with confidence
|
||||
|
||||
- each labeling has a labeling `note`, basically comments by labeler, not sure
|
||||
what else we can use
|
||||
|
||||
- ~~Use the score for "overall", "accuracy", "coverage", "coherence" from
|
||||
axis/evals to train an addition model (rank additional aspect of the policy
|
||||
model)~~
|
||||
|
||||
- this should be placed under experimental_dataset.py
|
||||
|
||||
2. Add support for anthropic dataset
|
||||
|
||||
- anthropic dataset is more like a conversation tree which is much complex than
|
||||
simply question-answer schema
|
||||
|
||||
- this is basically a MCTS from alphazero.
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
|
||||
classification based ranking
|
||||
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .utils import webgpt_return_format
|
||||
|
||||
|
||||
class WebGPTDataset(Dataset):
|
||||
def __init__(self, mode="train", index_cache="dataset/webgpt_train_idx.pt", additional_dataset=None) -> None:
|
||||
super().__init__()
|
||||
"""
|
||||
mode : train or val, used for validation purpose, has nothing to do with original split
|
||||
additional_dataset : a list of jsonline format with idx, question and texts (generate candidates)
|
||||
idx : must match the index you iterate from comparison enumerate order
|
||||
question : for validation purpose
|
||||
texts : list of K generate results from the question prompt
|
||||
"""
|
||||
os.makedirs("dataset", exist_ok=True)
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
self.dataset = []
|
||||
self.dataset_index = []
|
||||
for idx, row in enumerate(dataset["train"]):
|
||||
self.dataset.append(webgpt_return_format(row))
|
||||
|
||||
# since this dataset was generated from 176B GPT-3
|
||||
# we needed some more sample generated from the starting model
|
||||
# since this model must rank model generated by GPT-3 being better than your starting model
|
||||
self.sample_additional = False
|
||||
if additional_dataset is not None:
|
||||
self.sample_additional = True
|
||||
self.additional = {}
|
||||
with open(additional_dataset, "r") as f:
|
||||
for line in f:
|
||||
row = json.loads(line)
|
||||
if row["idx"] in self.dataset_index:
|
||||
self.additional[row["idx"]] = row["negatives"]
|
||||
if len(self.additional) != len(self.dataset_index):
|
||||
for match_idx in self.dataset_index:
|
||||
if match_idx in self.additional:
|
||||
continue
|
||||
|
||||
idx = match_idx - 900
|
||||
while idx not in self.additional:
|
||||
idx -= 1
|
||||
self.additional[match_idx] = self.additional[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
row = self.dataset[index]
|
||||
if not self.sample_additional:
|
||||
return row["question"], row["pos"], row["neg"]
|
||||
|
||||
gen_neg = random.choice(self.additional[self.dataset_index[index]])
|
||||
return row["question"], row["pos"], row["neg"], gen_neg
|
||||
@@ -0,0 +1,9 @@
|
||||
model_name: bigscience/bloomz-560m
|
||||
learning_rate: 3e-5
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
max_length: 600
|
||||
freeze_layer: 12
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- hfsummary
|
||||
@@ -0,0 +1,10 @@
|
||||
model_name: bigscience/bloomz-560m
|
||||
learning_rate: 3e-5
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
max_length: 600
|
||||
freeze_layer: 12
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,3 @@
|
||||
model_name: google/electra-large-discriminator
|
||||
learning_rate: 3e-5
|
||||
max_length: 300
|
||||
@@ -0,0 +1,13 @@
|
||||
model_name: facebook/galactica-125m
|
||||
learning_rate: 1e-5
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,14 @@
|
||||
model_name: facebook/galactica-1.3b
|
||||
learning_rate: 6e-6
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
freeze_layer: 20
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 400
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,14 @@
|
||||
model_name: facebook/galactica-125m
|
||||
learning_rate: 1e-5
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 10
|
||||
per_device_train_batch_size: 6
|
||||
warmup_steps: 600
|
||||
loss: cls
|
||||
eval_steps: 200
|
||||
save_steps: 500
|
||||
max_length: 128
|
||||
num_train_epochs: 2
|
||||
datasets:
|
||||
- webgpt
|
||||
- hfsummary
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
HFSummary
|
||||
|
||||
I want to train a multi regression model on axis_evals dataset mainly we can estimate the score of these score
|
||||
|
||||
- {"overall": "6", "accuracy": "6", "coverage": "6", "coherence": "7"}
|
||||
|
||||
Should be better than just a preference score
|
||||
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSummaryScore:
|
||||
"""
|
||||
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, batch):
|
||||
|
||||
features = []
|
||||
labels = []
|
||||
for feature, label in batch:
|
||||
features.append(feature)
|
||||
labels.append(label)
|
||||
|
||||
batch_feature = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch_feature.pop("token_type_ids")
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
batch_feature["labels"] = torch.from_numpy(np.array(labels)).float()
|
||||
return batch_feature
|
||||
|
||||
|
||||
class HFSummaryQuality(Dataset):
|
||||
def __init__(self, split, tokenizer, max_length=300) -> None:
|
||||
super().__init__()
|
||||
assert split in ("validation", "test")
|
||||
dataset = load_dataset("Tristan/summarize_from_feedback", "axis")[split]
|
||||
self.max_length = max_length
|
||||
mean_scores = defaultdict(list)
|
||||
self.contexts = []
|
||||
self.responses = []
|
||||
self.labels = []
|
||||
for data in dataset:
|
||||
|
||||
if "article" in data["info"] and data["info"]["article"] is not None:
|
||||
context = data["info"]["article"]
|
||||
elif "post" in data["info"]:
|
||||
context = data["info"]["post"]
|
||||
self.contexts.append(context)
|
||||
|
||||
response = data["summary"]["text"]
|
||||
self.responses.append(response)
|
||||
self.labels.append(data["summary"]["axes"])
|
||||
for axis, score in data["summary"]["axes"].items():
|
||||
if score is not None:
|
||||
mean_scores[axis].append(score)
|
||||
|
||||
self.label2idx = {key: idx for idx, key in enumerate(mean_scores.keys())}
|
||||
self.label2mean = {key: np.mean(scores) for key, scores in mean_scores.items()}
|
||||
self.tokenizer = tokenizer
|
||||
print(self.label2idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.responses)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context = self.contexts[index]
|
||||
# return pairs of comparison
|
||||
response = self.responses[index]
|
||||
labels = np.zeros(len(self.label2idx))
|
||||
for key, score in self.labels[index].items():
|
||||
labels[self.label2idx[key]] = (self.label2mean[key] if score is None else score) / 10
|
||||
return self.tokenizer(context, response, truncation=True, max_length=self.max_length), labels
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
author: theblackcat102
|
||||
|
||||
Dataset output format from __getitem__
|
||||
|
||||
- question / prompt : string
|
||||
|
||||
- answers / rows : list of tuple pair. The first element in the tuple pair must be the positive pair (rank higher than the second element)
|
||||
|
||||
A list of rank based dataset for training using rank loss
|
||||
|
||||
Some nice features to have
|
||||
|
||||
[] support additional negative samples generated from other models.
|
||||
|
||||
For example we can use galactica-125m to generate a TLDR and assume it was
|
||||
inferior than the human perference one
|
||||
|
||||
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPairRank:
|
||||
"""
|
||||
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
num_choices: int = 2
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
drop_token_type: bool = False # galactica
|
||||
|
||||
def __call__(self, features):
|
||||
|
||||
flatten_features = []
|
||||
batch_size = 0
|
||||
for question, pairs in features:
|
||||
for (pos, neg) in pairs:
|
||||
flatten_features.append(self.tokenizer(question, pos, truncation=True, max_length=self.max_length))
|
||||
flatten_features.append(self.tokenizer(question, neg, truncation=True, max_length=self.max_length))
|
||||
batch_size += 1
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flatten_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if self.drop_token_type:
|
||||
batch.pop("token_type_ids")
|
||||
# batch = {k: v.view(batch_size, self.num_choices, -1) for k, v in batch.items()}
|
||||
return batch
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
dataset = load_dataset("openai/webgpt_comparisons")
|
||||
questions = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2question = {}
|
||||
for row in dataset["train"]:
|
||||
question = row["question"]["full_text"]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
if question not in questions:
|
||||
questions[question] = []
|
||||
|
||||
if row["score_0"] > row["score_1"]:
|
||||
# not going to risk it
|
||||
questions[question].append((row["answer_0"], row["answer_1"]))
|
||||
else:
|
||||
questions[question].append((row["answer_1"], row["answer_0"]))
|
||||
|
||||
self.questions = questions
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2question)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
rows = self.questions[question]
|
||||
# optimize the format later
|
||||
return question, rows
|
||||
|
||||
|
||||
class HFSummary(Dataset):
|
||||
"""
|
||||
Human feedback data from OpenAI
|
||||
https://github.com/openai/summarize-from-feedback
|
||||
|
||||
labeling method : pair comparison, 0 or 1
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, split="train", conf_threshold=-1, max_comparison_per_sample=3) -> None:
|
||||
super().__init__()
|
||||
assert split in ("train", "valid1", "valid2", "test")
|
||||
summaries = {}
|
||||
# using prompt as our index will allows us
|
||||
# to add additional generated prompt later
|
||||
self.index2summary = {}
|
||||
self.max_comparison_per_sample = max_comparison_per_sample
|
||||
major_split = split if "train" == split else "validation"
|
||||
dataset = load_dataset("Tristan/summarize_from_feedback", "comparisons")[major_split]
|
||||
for data in dataset:
|
||||
if (
|
||||
"extra" in data
|
||||
and "confidence" in data["extra"]
|
||||
and data["extra"]["confidence"] is not None
|
||||
and conf_threshold > data["extra"]["confidence"]
|
||||
):
|
||||
print("skipping {}".format(data["info"]["id"]))
|
||||
continue
|
||||
|
||||
if split != "train" and split != data["split"]:
|
||||
continue
|
||||
|
||||
if "article" in data["info"] and data["info"]["article"] is not None:
|
||||
context = data["info"]["article"]
|
||||
elif "post" in data["info"]:
|
||||
context = data["info"]["post"]
|
||||
|
||||
if context not in self.index2summary:
|
||||
self.index2summary[len(self.index2summary)] = context
|
||||
|
||||
if context not in summaries:
|
||||
summaries[context] = []
|
||||
|
||||
pos, neg = (0, 1) if data["choice"] == 0 else (1, 0)
|
||||
summaries[context].append((data["summaries"][pos]["text"], data["summaries"][neg]["text"]))
|
||||
|
||||
self.summaries = summaries
|
||||
|
||||
self.postfix_prompt = " TLDR;"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2summary)
|
||||
|
||||
def __getitem__(self, index):
|
||||
context = self.index2summary[index]
|
||||
# return pairs of comparison
|
||||
rows = self.summaries[context]
|
||||
# pair very big
|
||||
# we are going to do some sampling
|
||||
# not optimal but good for now
|
||||
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
|
||||
# optimize the format later
|
||||
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
|
||||
@@ -0,0 +1,6 @@
|
||||
datasets==2.8.0
|
||||
evaluate==0.4.0
|
||||
scikit-learn==1.2.0
|
||||
torch==1.12.1+cu116
|
||||
transformers==4.25.1
|
||||
wandb==0.13.7
|
||||
@@ -0,0 +1,155 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "quality-scoring"
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
accuracy = evaluate.load("mse")
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
return accuracy.compute(predictions=predictions.flatten(), references=labels.flatten())
|
||||
|
||||
|
||||
class QualityTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = nn.L1Loss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = self.sigmoid(outputs.get("logits"))
|
||||
loss = self.loss_fct(logits, labels)
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = inputs["labels"]
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf["model_name"]
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
collate_fn = DataCollatorForSummaryScore(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
train = HFSummaryQuality(split="validation", tokenizer=tokenizer, max_length=training_conf["max_length"])
|
||||
eval = HFSummaryQuality(split="test", tokenizer=tokenizer, max_length=training_conf["max_length"])
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=len(train.label2idx), problem_type="regression"
|
||||
)
|
||||
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
)
|
||||
trainer = QualityTrainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
@@ -0,0 +1,40 @@
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def test_hfsummary():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
|
||||
dataset = HFSummary("train")
|
||||
print(len(dataset))
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)
|
||||
for batch in dataloader:
|
||||
batch["input_ids"].shape
|
||||
|
||||
|
||||
def test_webgpt():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=200)
|
||||
dataset = WebGPT()
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_hf_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
|
||||
dataset = HFSummaryQuality("validation", tokenizer)
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_quality()
|
||||
# test_webgpt()
|
||||
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "reward-model"
|
||||
|
||||
accuracy = evaluate.load("accuracy")
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str = "rank"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
return accuracy.compute(predictions=predictions, references=[0] * predictions.shape[0])
|
||||
|
||||
|
||||
class RankLoss(nn.Module):
|
||||
def __init__(self, eps=1e-8) -> None:
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.log_sigmoid = nn.LogSigmoid()
|
||||
|
||||
def forward(self, pos, neg):
|
||||
return -self.log_sigmoid(pos - neg + self.eps).mean()
|
||||
|
||||
|
||||
class RankTrainer(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits").view(-1, 2)
|
||||
if self.loss_function == "rank":
|
||||
loss = self.loss_fct(logits[:, 0], logits[:, 1])
|
||||
else:
|
||||
loss = self.loss_fct(logits, torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long))
|
||||
|
||||
return loss, logits
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
# compute loss on predict data
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
loss = loss.mean().detach()
|
||||
labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
training_conf = argument_parsing(parser)
|
||||
|
||||
model_name = training_conf["model_name"]
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, problem_type="regression")
|
||||
if "freeze_layer" in training_conf:
|
||||
num_layer = training_conf["freeze_layer"]
|
||||
model = freeze_top_n_layers(model, num_layer)
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
args = CustomTrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=True,
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
per_device_eval_batch_size=training_conf["per_device_eval_batch_size"],
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=2.0,
|
||||
logging_steps=10,
|
||||
save_total_limit=4,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset)
|
||||
train_datasets.append(train)
|
||||
evals["webgpt"] = eval
|
||||
if "hfsummary" in training_conf["datasets"]:
|
||||
sum_train = HFSummary(split="train")
|
||||
train_datasets.append(sum_train)
|
||||
sum_eval = HFSummary(split="valid1")
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
collate_fn = DataCollatorForPairRank(
|
||||
tokenizer, max_length=training_conf["max_length"], drop_token_type="galactica" in model_name
|
||||
)
|
||||
assert len(evals) > 0
|
||||
trainer = RankTrainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
@@ -0,0 +1,99 @@
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
|
||||
|
||||
def webgpt_return_format(row):
|
||||
if row["score_0"] >= row["score_1"]:
|
||||
# remove this to prevent information leak, since we are not using reference
|
||||
return {
|
||||
"question": row["question"]["full_text"],
|
||||
"pos": re_reference_remove.sub("", row["answer_0"]),
|
||||
"neg": re_reference_remove.sub("", row["answer_1"]),
|
||||
}
|
||||
|
||||
return {
|
||||
"question": row["question"]["full_text"],
|
||||
"pos": re_reference_remove.sub("", row["answer_1"]),
|
||||
"neg": re_reference_remove.sub("", row["answer_0"]),
|
||||
}
|
||||
|
||||
|
||||
def get_tokenizer(tokenizer_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if "galactica" in tokenizer_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
train_idx, val_idx = train_test_split(
|
||||
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
|
||||
)
|
||||
# [3879, 11479, 8341, 9177, 10798, 18177, 5735, 15669, 4837, 2760]
|
||||
print(val_idx[:10])
|
||||
# [13582, 5919, 11875, 7373, 19135, 13706, 8555, 15788, 15005, 15209]
|
||||
print(train_idx[:10])
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def freeze_top_n_layers(model, target_layers):
|
||||
# its possible we can simply detect which module is a ModuleList
|
||||
# and simply freeze the module without doing string parsing
|
||||
for name, param in model.named_parameters():
|
||||
if "embed" in name:
|
||||
param.requires_grad = False
|
||||
elif ".layer" in name or ".h." in name:
|
||||
tokens = name.split(".")
|
||||
idx = 0
|
||||
for token in tokens:
|
||||
if "layer" in token or token == "h":
|
||||
break
|
||||
idx += 1
|
||||
if idx >= len(tokens):
|
||||
continue
|
||||
|
||||
layer_ = int(tokens[idx + 1])
|
||||
if layer_ < target_layers:
|
||||
# print('freeze ', layer_, name)
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def argument_parsing(parser):
|
||||
default_params = {
|
||||
"num_train_epochs": 4,
|
||||
"learning_rate": 3e-5,
|
||||
"eval_steps": 500,
|
||||
"loss": "rank",
|
||||
"max_length": 440,
|
||||
"per_device_eval_batch_size": 5,
|
||||
"per_device_train_batch_size": 8,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"gradient_checkpointing": False,
|
||||
"datasets": ["webgpt"],
|
||||
}
|
||||
args = parser.parse_args()
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
training_conf = yaml.safe_load(f.read())
|
||||
|
||||
params = {**default_params, **training_conf}
|
||||
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
|
||||
params["num_train_epochs"] = int(params["num_train_epochs"])
|
||||
params["per_device_train_batch_size"] = int(params["per_device_train_batch_size"])
|
||||
params["learning_rate"] = float(params["learning_rate"])
|
||||
return params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloomz-560m")
|
||||
freeze_top_n_layers(model, 10)
|
||||
print(model.state_dict().keys())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user