Merge remote-tracking branch 'upstream/main'

This commit is contained in:
rsandb
2022-12-26 11:11:39 -06:00
55 changed files with 670 additions and 126 deletions
+6 -2
View File
@@ -3,7 +3,10 @@ name: Build
on:
workflow_call:
inputs:
folder:
dockerfile:
required: true
type: string
context:
required: true
type: string
image-name:
@@ -48,7 +51,8 @@ jobs:
- name: Build and push Docker image
uses: docker/build-push-action@v3.2.0
with:
context: ${{ inputs.folder }}
file: ${{ inputs.dockerfile }}
context: ${{ inputs.context }}
build-args: ${{ inputs.build-args }}
push: true
tags: ${{ steps.meta.outputs.tags }}
+35 -1
View File
@@ -9,5 +9,39 @@ jobs:
uses: ./.github/workflows/docker-build.yaml
with:
image-name: oasst-backend
folder: backend
context: .
dockerfile: docker/Dockerfile.backend
build-args: ""
build-web:
uses: ./.github/workflows/docker-build.yaml
with:
image-name: oasst-web
context: .
dockerfile: docker/Dockerfile.website
build-args: ""
build-bot:
uses: ./.github/workflows/docker-build.yaml
with:
image-name: oasst-discord-bot
context: .
dockerfile: docker/Dockerfile.discord-bot
build-args: ""
deploy-dev:
needs: [build-backend, build-web, build-bot]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Run playbook
uses: dawidd6/action-ansible-playbook@v2
with:
# Required, playbook filepath
playbook: dev.yaml
# Optional, directory where playbooks live
directory: ansible
# Optional, SSH private key
key: ${{secrets.DEV_NODE_PRIVATE_KEY}}
# Optional, literal inventory file contents
inventory: |
[dev]
dev01 ansible_host=${{secrets.DEV_NODE_IP}} ansible_connection=ssh ansible_user=web-team
+10
View File
@@ -46,3 +46,13 @@ repos:
hooks:
- id: prettier
args: ["--write"]
- repo: local
hooks:
- id: next-lint-website
name: Lint website
files: ^website/
types_or: [javascript, jsx, ts, tsx]
language: system
pass_filenames: false
entry: bash -c 'cd website && npm install && npm run lint'
+17 -27
View File
@@ -4,6 +4,20 @@ Open Assistant is a project meant to give everyone access to a great chat based
We believe that by doing this we will create a revolution in innovation in language. In the same way that stable-diffusion helped the world make art and images in new ways we hope Open Assistant can help improve the world by improving language itself.
## The Plan
We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
1. Collect high-quality human generated Instruction-Fulfillment samples (prompt + response), goal >50k. We design a crowdsourced process to collect and reviewed prompts. We do not want to train on flooding/toxic/spam/junk/personal information data. We will have a leaderboard to motivate the community that shows progress and the most active users. Swag will be given to the top-contributors.
2. For each of the collected prompts we will sample multiple completions. Completions of one prompt will then be shown randomly to users to rank them from best to worst. Again this should happen crowd-sourced, e.g. we need to deal with unreliable potentially malicious users. At least multiple votes by independent users have to be collected to measure the overall agreement. The gathered ranking-data will be used to train a reward model.
3. Now follows the RLHF training phase based on the prompts and the reward model.
We can then take the resulting model and continue with completion sampling step 2 for a next iteration.
## The Vision
We are not going to stop at replicating ChatGPT. We want to build the assistant of the future, able to not only write email and cover letters, but do meaningful work, use APIs, dynamically research information, and much more, with the ability to be personalized and extended by anyone. And we want to do this in a way that is open and accessible, which means we must not only build a great assistant, but also make it small and efficient enough to run on consumer hardware.
## How can you help?
All open source projects begins with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity.
@@ -12,7 +26,7 @@ All open source projects begins with people like you. Open source is the belief
[Fill out the contributor signup form](https://docs.google.com/forms/d/e/1FAIpQLSeuggO7UdYkBvGLEJldDvxp6DwaRbW5p7dl96UzFkZgziRTrQ/viewform)
[Join the LAION Discord Server!](https://discord.gg/RQFtmAmk)
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
[Visit the Notion](https://ykilcher.com/open-assistant)
@@ -43,30 +57,6 @@ Install `pre-commit` and run `pre-commit install` to install the pre-commit hook
In case you haven't done this, have already committed, and CI is failing, you can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
# (Older version of the readme below)
### Deployment
## How do I start helping out?
Check out these pages to learn more about the project.
Ping Birger on discord if you want help to get started.
http://**discordapp.com/users/birger#6875**
## More information in the notion
https://roan-iguanadon-a58.notion.site/Open-Chat-Gpt-83dd217eeeb84907a155b8a9d716fa46
## Code structure
### Bot
We have a folder named bot where code related to the bot lives.
### Backend
We have a backend folder for backend development of the api that the discord bot sends it information to.
### Website
We have a folder for the website, live at https://projects.laion.ai/Open-Chat-GPT/ .The website is built using Next.js
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
+1
View File
@@ -0,0 +1 @@
*.local.yaml
+77
View File
@@ -0,0 +1,77 @@
# ansible playbook to set up some docker containers
- name: Set up a dev node
hosts: dev
gather_facts: true
tasks:
- name: Create network
community.docker.docker_network:
name: oasst
state: present
driver: bridge
- name: Create postgres containers
community.docker.docker_container:
name: "{{ item.name }}"
image: postgres:15
state: started
restart_policy: always
network_mode: oasst
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
volumes:
- "{{ item.name }}:/var/lib/postgresql/data"
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
loop:
- name: oasst-postgres
- name: oasst-postgres-web
- name: Set up maildev
community.docker.docker_container:
name: oasst-maildev
image: maildev/maildev
state: started
restart_policy: always
network_mode: oasst
- name: Run the oasst oasst-backend
community.docker.docker_container:
name: oasst-backend
image: ghcr.io/laion-ai/open-assistant/oasst-backend
state: started
pull: true
restart_policy: always
network_mode: oasst
env:
POSTGRES_HOST: oasst-postgres
DEBUG_ALLOW_ANY_API_KEY: "true"
MAX_WORKERS: "1"
ports:
- 8080:8080
- name: Run the oasst oasst-web frontend
community.docker.docker_container:
name: oasst-web
image: ghcr.io/laion-ai/open-assistant/oasst-web
state: started
pull: true
restart_policy: always
network_mode: oasst
env:
FASTAPI_URL: http://oasst-backend:8080
FASTAPI_KEY: "123"
DATABASE_URL: postgres://postgres:postgres@oasst-postgres-web/postgres
NEXTAUTH_SECRET: O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
EMAIL_SERVER_HOST: oasst-maildev
EMAIL_SERVER_PORT: "25"
EMAIL_FROM: info@example.com
NEXTAUTH_URL: http://web.dev.open-assistant.io/
ports:
- 3000:3000
command: bash wait-for-postgres.sh node server.js
+1
View File
@@ -7,6 +7,7 @@ Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Adds text labels table.
Revision ID: 067c4002f2d9
Revises: 0daec5f8135f
Create Date: 2022-12-25 17:05:21.208843
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "067c4002f2d9"
down_revision = "0daec5f8135f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"text_labels",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["post_id"],
["post.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("text_labels")
# ### end Alembic commands ###
+16 -16
View File
@@ -37,21 +37,21 @@ def api_auth(
db: Session,
) -> ApiClient:
if api_key is not None:
if settings.ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token")
db.add(api_client)
db.commit()
return api_client
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials")
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
if api_client is not None and api_client.enabled:
return api_client
+2 -1
View File
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks
from oasst_backend.api.v1 import tasks, text_labels
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
import pydantic
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
class LabelTextRequest(pydantic.BaseModel):
text_labels: protocol_schema.TextLabels
user: protocol_schema.User
@router.post("/")
def label_text(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: LabelTextRequest,
) -> None:
"""
Label a piece of text.
"""
api_client = deps.api_auth(api_key, db)
try:
logger.info(f"Labeling text {request=}.")
pr = PromptRepository(db, api_client, user=request.user)
pr.store_text_labels(request.text_labels)
except Exception:
logger.exception("Failed to store label.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
+2 -1
View File
@@ -15,7 +15,8 @@ class Settings(BaseSettings):
POSTGRES_DB: str = "postgres"
DATABASE_URI: Optional[PostgresDsn] = None
ALLOW_ANY_API_KEY: bool = False
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
+2
View File
@@ -4,6 +4,7 @@ from .person import Person
from .person_stats import PersonStats
from .post import Post
from .post_reaction import PostReaction
from .text_labels import TextLabels
from .work_package import WorkPackage
__all__ = [
@@ -13,4 +14,5 @@ __all__ = [
"Post",
"PostReaction",
"WorkPackage",
"TextLabels",
]
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class TextLabels(SQLModel, table=True):
__tablename__ = "text_labels"
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
+15 -1
View File
@@ -5,7 +5,7 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -314,3 +314,17 @@ class PromptRepository:
self.db.commit()
self.db.refresh(reaction)
return reaction
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
model = TextLabels(
api_client_id=self.api_client.id,
text=text_labels.text,
labels=text_labels.labels,
)
if text_labels.has_post_id:
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
model.post_id = text_labels.post_id
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
+1
View File
@@ -5,6 +5,7 @@ numpy==1.22.4
psycopg2-binary==2.9.5
pydantic==1.9.1
python-dotenv==0.21.0
scipy==1.8.1
SQLAlchemy==1.4.41
sqlmodel==0.0.8
starlette==0.22.0
-3
View File
@@ -1,3 +0,0 @@
:robot: **Challenge: Assistant Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
-3
View File
@@ -1,3 +0,0 @@
:microphone2: **Challenge: Initial Prompt**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+1 -2
View File
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from bot_settings import settings
from bot import OpenAssistantBot
from bot_settings import settings
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
View File
@@ -10,6 +10,6 @@ Commands for bot owners:
`!sync`
`!sync.guild`
`!sync.copy_global`
`!sync.copy_global`
`!sync.clear_guild`
{% endif %}
{% endif %}
@@ -9,4 +9,4 @@ Here is the conversation so far:
**{{ message.text }}**"
{% endif %}
{% endfor %}
:robot: Assistant: { human, pls help me! ... }
:robot: Assistant: { human, pls help me! ... }
@@ -1,4 +1,4 @@
Please provide an initial prompt to the assistant.
{% if task.hint is not none %}
Hint: {{task.hint}}
{% endif %}
{% endif %}
@@ -10,4 +10,4 @@ Rank the following replies:
{% for reply in task.replies %}
{{loop.index}}: {{reply}}{% endfor %}
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
@@ -2,4 +2,4 @@ Rank the following prompts:
{% for prompt in task.prompts %}
{{loop.index}}: {{prompt}}{% endfor %}
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
@@ -9,4 +9,4 @@ Here is the conversation so far:
{% endif %}{% endfor %}
{% if task.hint %}
Hint: {{ task.hint }}
{% endif %}
{% endif %}
@@ -0,0 +1,3 @@
:robot: **Challenge: Assistant Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:microphone2: **Challenge: Initial Prompt**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -1,3 +1,3 @@
:bar_chart: **Challenge: Rank Replies**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -1,3 +1,3 @@
:bar_chart: **Challenge: Rank Initial Prompts**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -1,3 +1,3 @@
:ballot_box: **Challenge: Rate Summary**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -1,3 +1,3 @@
:books: **Challenge: Summarize Story**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -1,3 +1,3 @@
:person_red_hair: **Challenge: User Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -204,3 +204,51 @@ AnyInteraction = Union[
PostRating,
PostRanking,
]
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
spam = "spam"
violence = "violence"
sexual_content = "sexual_content"
toxicity = "toxicity"
political_content = "political_content"
humor = "humor"
sarcasm = "sarcasm"
hate_speech = "hate_speech"
profanity = "profanity"
ad_hominem = "ad_hominem"
insult = "insult"
threat = "threat"
aggressive = "aggressive"
misleading = "misleading"
helpful = "helpful"
formal = "formal"
cringe = "cringe"
creative = "creative"
beautiful = "beautiful"
informative = "informative"
based = "based"
slang = "slang"
class TextLabels(BaseModel):
"""A set of labels for a piece of text."""
text: str
labels: dict[TextLabel, float]
post_id: str | None = None
@property
def has_post_id(self) -> bool:
"""Whether this TextLabels has a post_id."""
return bool(self.post_id)
# check that each label value is between 0 and 1
@pydantic.validator("labels")
def check_label_values(cls, v):
for key, value in v.items():
if not (0 <= value <= 1):
raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
return v
@@ -9,6 +9,11 @@ services:
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
adminer:
image: adminer
+1 -1
View File
@@ -4,7 +4,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to backend directory
pushd "$parent_path/../../backend"
export ALLOW_ANY_API_KEY=True
export DEBUG_SKIP_API_KEY_CHECK=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0
+15 -49
View File
@@ -3,69 +3,35 @@ version: "3.7"
services:
# This DB is for the FastAPI Backend.
db:
image: postgres
restart: always
ports:
- 5432:5432
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
extends:
file: ../frontend-development/docker-compose.yaml
service: db
# This DB is for Web Authentication and data caching.
webdb:
image: postgres
restart: always
ports:
- 5433:5432
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
extends:
file: ../frontend-development/docker-compose.yaml
service: webdb
# This lets you manually inspect the web and backend databases.
adminer:
image: adminer
restart: always
ports:
- 8089:8080
extends:
file: ../frontend-development/docker-compose.yaml
service: adminer
# This fakes an SMTP email server used by website authentication.
# User registration emails can be found by going to localhost:1080 and
# opening the emails listed.
maildev:
image: maildev/maildev
restart: always
environment:
- MAILDEV_WEB_PORT=1080
- MAILDEV_SMTP_PORT=1025
ports:
- "1080:1080"
- "1025:1025"
extends:
file: ../frontend-development/docker-compose.yaml
service: maildev
# The oassist backend service.
backend:
build:
dockerfile: docker/Dockerfile.backend
context: ../../
image: oasst-backend
environment:
- POSTGRES_HOST=db
- ALLOW_ANY_API_KEY=True
- MAX_WORKERS=1
depends_on:
db:
condition: service_healthy
ports:
- "8080:8080"
extends:
file: ../frontend-development/docker-compose.yaml
service: backend
# The oassist web service.
web:
@@ -6,11 +6,6 @@ services:
extends:
file: ../backend-development/docker-compose.yaml
service: db
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
# This DB is for Web Authentication and data caching.
webdb:
@@ -32,6 +27,7 @@ services:
extends:
file: ../backend-development/docker-compose.yaml
service: adminer
backend:
build:
dockerfile: docker/Dockerfile.backend
@@ -39,7 +35,7 @@ services:
image: oasst-backend
environment:
- POSTGRES_HOST=db
- ALLOW_ANY_API_KEY=True
- DEBUG_SKIP_API_KEY_CHECK=True
- MAX_WORKERS=1
depends_on:
db:
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
import numpy as np
from scipy import log2
from scipy.integrate import nquad
from scipy.special import gammaln, psi
from scipy.stats import dirichlet
def make_range(*x):
"""
constructs leftover values for the simplex given the first k entries
(0,x_k) = 1-(x_1+...+x_(k-1))
"""
return (0, max(0, 1 - sum(x)))
def relative_entropy(p, q):
"""
relative entropy of the two given dirichlet distributions
"""
def tmp(*x):
"""
First adds the last always forced entry to the input (the last x_last = 1-(x_1+...+x_(N)) )
Then computes the relative entropy of posterior and prior for that datapoint
"""
x_new = np.append(x, 1 - sum(x))
return p(x_new) * log2(p(x_new) / q(x_new))
return tmp
def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
s = np.random.rand(dim - 1, samples)
s = np.sort(np.concatenate((np.zeros((1, samples)), s, np.ones((1, samples)))), 0)
# print(s)
pos = np.diff(s, axis=0)
# print(pos)
res = fun(pos)
return np.mean(res)
def analytic_solution(a_post, a_prior):
"""
Analytic solution to the KL-divergence between two dirichlet distributions.
Proof is in the Notion design doc.
"""
post_sum = np.sum(a_post)
prior_sum = np.sum(a_prior)
info = (
gammaln(post_sum)
- gammaln(prior_sum)
- np.sum(gammaln(a_post))
+ np.sum(gammaln(a_prior))
- np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum)))
)
return info
def infogain(a_post, a_prior):
raise (
"""For the love of good don't use this:
it's insanely poorly conditioned, the worst numerical code I have ever written
and it's slow as molasses. Use the analytic solution instead.
Maybe remove
"""
)
args = len(a_prior)
p = dirichlet(a_post).pdf
q = dirichlet(a_prior).pdf
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
return info
def uniform_expected_infogain(a_prior):
mean_weight = dirichlet.mean(a_prior)
print("weight", mean_weight)
results = []
for i, w in enumerate(mean_weight):
a_post = a_prior.copy()
a_post[i] = a_post[i] + 1
results.append(w * analytic_solution(a_post, a_prior))
return np.sum(results)
if __name__ == "__main__":
a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
a_post = np.array([1, 1, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
print("algebraic", analytic_solution(a_post, a_prior))
# print("raw",infogain(a_post, a_prior))
print("large infogain", uniform_expected_infogain(a_prior))
print("post infogain", uniform_expected_infogain(a_post))
# a_prior = np.array([1,1,1000])
# print("small infogain",uniform_expected_infogain(a_prior))
+183
View File
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
from dataclasses import dataclass, replace
from typing import Any
import numpy as np
import numpy.typing as npt
from scipy.stats import kendalltau
@dataclass
class Voter:
"""
Represents a single voter.
This tabulates the number of good votes, total votes,
and points.
We only put well-behaved people on the scoreboard and filter out the badly behaved ones
"""
uid: Any
num_votes: int
num_good_votes: int
num_prompts: int
num_good_prompts: int
num_rankings: int
num_good_rankings: int
#####################
voting_points: int
prompt_points: int
ranking_points: int
def voter_quality(self):
return self.num_good_votes / self.num_votes
def rank_quality(self):
return self.num_good_rankings / self.num_rankings
def prompt_quality(self):
return self.num_good_prompts / self.num_prompts
def is_well_behaved(self, threshhold_vote, threshhold_prompt, threshhold_rank):
return (
self.voter_quality() > threshhold_vote
and self.prompt_quality() > threshhold_prompt
and self.rank_quality() > threshhold_rank
)
def total_points(self, voting_weight, prompt_weight, ranking_weight):
return (
voting_weight * self.voting_points
+ prompt_weight * self.prompt_points
+ ranking_weight * self.ranking_points
)
def score_update_votes(new_vote: int, consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the new "quality score" and points for a voter,
after that voter cast a vote on a question.
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
The consensus is the array of all votes cast by all voters for that question
We then update the voter data using the new information
Parameters:
new_vote (int): the index of the vote cast by the voter
consensus (ArrayLike): all votes cast for this question
voter_data (Voter): a "Voter" object that represents the person casting the "new_vote"
Returns:
updated_voter (Voter): the new "quality score" and points for the voter
"""
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
# since 100 is the lowest, 300 the highest and 200 the middle value
consensus_ranking = np.argsort(np.argsort(consensus))
new_points = consensus_ranking[new_vote] + voter_data.voting_points
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
# it's a good vote
new_good_votes = int(consensus_ranking[new_vote] > (len(consensus) - 1) / 2) + voter_data.num_good_votes
new_num_votes = voter_data.num_votes + 1
return replace(voter_data, num_votes=new_num_votes, num_good_votes=new_good_votes, voting_points=new_points)
def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the gain of points for a given prompt's votes
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
Parameters:
consensus (ArrayLike): all votes cast for this question
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
Returns:
updated_voter (Voter): the new "quality score" and points for the voter
"""
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
# since 100 is the lowest, 300 the highest and 200 the middle value
consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1
delta_votes = np.sum(consensus_ranking * consensus)
new_points = delta_votes + voter_data.prompt_points
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
# it's a good vote
new_good_prompts = int(delta_votes > 0) + voter_data.num_good_prompts
new_num_prompts = voter_data.num_prompts + 1
return replace(
voter_data,
num_prompts=new_num_prompts,
num_good_prompts=new_good_prompts,
prompt_points=new_points,
)
def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the gain of points for a given ranking's votes
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
we use the bubble-sort distance (or "kendall-tau" distance) to compare the two rankings
we use this over spearman correlation since:
"[Kendall's τ] approaches a normal distribution more rapidly than ρ, as N, the sample size, increases;
and τ is also more tractable mathematically, particularly when ties are present"
Gilpin, A. R. (1993). Table for conversion of Kendall's Tau to Spearman's
Rho within the context measures of magnitude of effect for meta-analysis
Further in
"research design and statistical analyses, second edition, 2003"
the authors note that at least from an significance test POV they will yield the same p-values
Parameters:
user_ranking (ArrayLike): ranking produced by the user
consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
Returns:
updated_voter (Voter): the new "quality score" and points for the voter
"""
bubble_sort_distance, p_value = kendalltau(user_ranking, consensus_ranking)
# normalize kendall-tau from [-1,1] into [0,1] range
bubble_sort_distance = (1 + bubble_sort_distance) / 2
new_points = bubble_sort_distance + voter_data.ranking_points
new_good_rankings = int(bubble_sort_distance > 0.5) + voter_data.num_good_rankings
new_num_rankings = voter_data.num_rankings + 1
return replace(
voter_data,
num_rankings=new_num_rankings,
num_good_rankings=new_good_rankings,
ranking_points=new_points,
)
if __name__ == "__main__":
demo_voter = Voter(
"abc",
num_votes=10,
num_good_votes=2,
num_prompts=10,
num_good_prompts=2,
num_rankings=10,
num_good_rankings=2,
voting_points=6,
prompt_points=0,
ranking_points=0,
)
new_vote = 3
consensus = np.array([200, 300, 100, 500])
print(demo_voter)
print("best vote ", score_update_votes(new_vote, consensus, demo_voter))
new_vote = 2
print("worst vote ", score_update_votes(new_vote, consensus, demo_voter))
new_vote = 1
print("medium vote ", score_update_votes(new_vote, consensus, demo_voter))
print("prompt writer", score_update_prompts(consensus, demo_voter))
print("best rank ", score_update_ranking(np.array([0, 2, 1]), np.array([0, 2, 1]), demo_voter))
print("medium rank ", score_update_ranking(np.array([2, 0, 1]), np.array([0, 2, 1]), demo_voter))
print("worst rank ", score_update_ranking(np.array([1, 0, 2]), np.array([0, 2, 1]), demo_voter))
+1 -1
View File
@@ -16,7 +16,7 @@ export default function Error() {
</Head>
<Header />
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
Sorry, the page you're looking for doesn't exist.
{"Sorry, the page you're looking for does not exist."}
</main>
<Footer />
</>