mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into sft-data-sampling
This commit is contained in:
@@ -38,14 +38,31 @@ jobs:
|
||||
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
|
||||
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
|
||||
MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }}
|
||||
MAX_INITIAL_PROMPT_REVIEW: ${{ vars.MAX_INITIAL_PROMPT_REVIEW }}
|
||||
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
|
||||
MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }}
|
||||
LONELY_CHILDREN_COUNT: ${{ vars.LONELY_CHILDREN_COUNT }}
|
||||
P_LONELY_CHILD_EXTENSION: ${{ vars.P_LONELY_CHILD_EXTENSION }}
|
||||
P_ACTIVATE_BACKLOG_TREE: ${{ vars.P_ACTIVATE_BACKLOG_TREE }}
|
||||
MIN_ACTIVE_RANKINGS_PER_LANG: ${{ vars.MIN_ACTIVE_RANKINGS_PER_LANG }}
|
||||
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
|
||||
MESSAGE_SIZE_LIMIT: ${{ vars.MESSAGE_SIZE_LIMIT }}
|
||||
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
|
||||
STATS_INTERVAL_DAY: ${{ vars.STATS_INTERVAL_DAY }}
|
||||
STATS_INTERVAL_WEEK: ${{ vars.STATS_INTERVAL_WEEK }}
|
||||
STATS_INTERVAL_MONTH: ${{ vars.STATS_INTERVAL_MONTH }}
|
||||
STATS_INTERVAL_TOTAL: ${{ vars.STATS_INTERVAL_TOTAL }}
|
||||
WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
|
||||
${{ secrets.WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }}
|
||||
WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY:
|
||||
${{ secrets.WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY }}
|
||||
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
|
||||
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }}
|
||||
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
|
||||
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }}
|
||||
LOGURU_LEVEL: ${{ vars.LOGURU_LEVEL }}
|
||||
MAINTENANCE_MODE: ${{ vars.MAINTENANCE_MODE }}
|
||||
BACKEND_URL: ${{ vars.BACKEND_URL }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
@@ -21,6 +21,8 @@ jobs:
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: "pip"
|
||||
cache-dependency-path: "**/requirements*.txt"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
- name: Post PR comment on failure
|
||||
if: failure() && github.event_name == 'pull_request_target'
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
name: Deploy to prod2
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- production2
|
||||
|
||||
jobs:
|
||||
deploy-to-prod:
|
||||
uses: ./.github/workflows/deploy-to-node.yaml
|
||||
secrets: inherit
|
||||
with:
|
||||
stack-name: production2
|
||||
image-tag: ${{ vars.PROD_IMAGE_TAG }}
|
||||
backend-port: 8280
|
||||
website-port: 3200
|
||||
@@ -23,6 +23,8 @@ jobs:
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: "pip"
|
||||
cache-dependency-path: "**/requirements*.txt"
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16
|
||||
|
||||
Vendored
+2
-1
@@ -1,4 +1,5 @@
|
||||
{
|
||||
"python.formatting.provider": "black",
|
||||
"python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"]
|
||||
"python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"],
|
||||
"prettier.singleQuote": false
|
||||
}
|
||||
|
||||
+7
-7
@@ -13,12 +13,12 @@ dedicated, but not as active, channel.
|
||||
|
||||
### Taking on Tasks
|
||||
|
||||
We have a growing task list
|
||||
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
|
||||
that appeals to you and make a comment that you'd like to work on it. Include in
|
||||
your comment a brief description of how you'll solve the problem and if there
|
||||
are any open questions you want to discuss. Once a project coordinator has
|
||||
assigned the issue to you, start working on it.
|
||||
We have a growing task list of
|
||||
[issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue that
|
||||
appeals to you and make a comment that you'd like to work on it. Include in your
|
||||
comment a brief description of how you'll solve the problem and if there are any
|
||||
open questions you want to discuss. Once a project coordinator has assigned the
|
||||
issue to you, start working on it.
|
||||
|
||||
If the issue is currently unclear but you are interested, please post in Discord
|
||||
and someone can help clarify the issue with more detail.
|
||||
@@ -140,4 +140,4 @@ automatically deploy the built release to the dev machine.
|
||||
### Contribute a Dataset
|
||||
|
||||
See
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/openassistant/datasets/README.md)
|
||||
|
||||
@@ -14,6 +14,12 @@
|
||||
|
||||
</div>
|
||||
|
||||
# Here is our website to collect data:
|
||||
|
||||
[open-assistant.io](https://open-assistant.io)
|
||||
|
||||
(project documentation lives [here](https://laion-ai.github.io/Open-Assistant/))
|
||||
|
||||
# Table of Contents
|
||||
|
||||
- [What is Open Assistant?](#what-is-open-assistant)
|
||||
@@ -38,9 +44,22 @@ improving language itself.
|
||||
|
||||
## Do you want to try it out?
|
||||
|
||||
If you are interested in taking a look at the current state of the project, you
|
||||
### Contributing to Data Collection
|
||||
|
||||
The data collection frontend is now live [here](https://open-assistant.io/). Log
|
||||
in and start taking on tasks! We want to collect a high volume of quality data.
|
||||
By submitting, ranking, and labelling model prompts and responses you will be
|
||||
directly helping to improve the capabilities of Open Assistant.
|
||||
|
||||
### Running Locally
|
||||
|
||||
**You do not need to run the project locally unless you are contributing to the
|
||||
development process. The website link above will take you to the public website
|
||||
where you can use the data collection app.**
|
||||
|
||||
If you would like to run the data collection app locally for development, you
|
||||
can set up an entire stack needed to run **Open-Assistant**, including the
|
||||
website, backend, and associated dependent services.
|
||||
website, backend, and associated dependent services, with Docker.
|
||||
|
||||
##### To start the demo, run this in the root directory of the repository:
|
||||
|
||||
@@ -51,6 +70,10 @@ docker compose up --build
|
||||
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
|
||||
interact with the website.
|
||||
|
||||
> **Note:** If an issue occurs with the build, please head to the
|
||||
> [FAQ](https://projects.laion.ai/Open-Assistant/docs/faq) and check out the
|
||||
> entries about Docker.
|
||||
|
||||
> **Note:** When logging in via email, navigate to `http://localhost:1080` to
|
||||
> get the magic email login link.
|
||||
|
||||
@@ -65,7 +88,7 @@ interact with the website.
|
||||
|
||||
## The Plan
|
||||
|
||||
##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
|
||||
##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155).
|
||||
|
||||
1. Collect high-quality human generated Instruction-Fulfillment samples
|
||||
(prompt + response), goal >50k. We design a crowdsourced process to collect
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
To test the ansible playbook on localhost run
|
||||
`ansible-playbook -i test.inventory.ini dev.yaml`.\
|
||||
In case you're missing the ansible docker depencency install it with `ansible-galaxy collection install community.docker`.\
|
||||
In case you're missing the ansible docker dependency install it with `ansible-galaxy collection install community.docker`.\
|
||||
Point Redis Insights to the Redis database by visiting localhost:8001 in a
|
||||
browser and select "I already have a database" followed by "Connect to a Redis
|
||||
Database".\
|
||||
|
||||
@@ -113,6 +113,9 @@
|
||||
TREE_MANAGER__MAX_ACTIVE_TREES:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_ACTIVE_TREES') |
|
||||
default('10', true) }}"
|
||||
TREE_MANAGER__MAX_INITIAL_PROMPT_REVIEW:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_INITIAL_PROMPT_REVIEW') |
|
||||
default('100', true) }}"
|
||||
TREE_MANAGER__MAX_TREE_DEPTH:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_TREE_DEPTH') | default('5',
|
||||
true) }}"
|
||||
@@ -122,6 +125,21 @@
|
||||
TREE_MANAGER__MAX_CHILDREN_COUNT:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_CHILDREN_COUNT') |
|
||||
default('3', true) }}"
|
||||
TREE_MANAGER__LONELY_CHILDREN_COUNT:
|
||||
"{{ lookup('ansible.builtin.env', 'LONELY_CHILDREN_COUNT') |
|
||||
default('2', true) }}"
|
||||
TREE_MANAGER__P_LONELY_CHILD_EXTENSION:
|
||||
"{{ lookup('ansible.builtin.env', 'P_LONELY_CHILD_EXTENSION') |
|
||||
default('0.75', true) }}"
|
||||
TREE_MANAGER__P_ACTIVATE_BACKLOG_TREE:
|
||||
"{{ lookup('ansible.builtin.env', 'P_ACTIVATE_BACKLOG_TREE') |
|
||||
default('0.1', true) }}"
|
||||
TREE_MANAGER__MIN_ACTIVE_RANKINGS_PER_LANG:
|
||||
"{{ lookup('ansible.builtin.env', 'MIN_ACTIVE_RANKINGS_PER_LANG') |
|
||||
default('0', true) }}"
|
||||
MESSAGE_SIZE_LIMIT:
|
||||
"{{ lookup('ansible.builtin.env', 'MESSAGE_SIZE_LIMIT') |
|
||||
default('2000', true) }}"
|
||||
USER_STATS_INTERVAL_DAY:
|
||||
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_DAY') |
|
||||
default('5', true) }}"
|
||||
@@ -134,6 +152,9 @@
|
||||
USER_STATS_INTERVAL_TOTAL:
|
||||
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_TOTAL') |
|
||||
default('240', true) }}"
|
||||
LOGURU_LEVEL:
|
||||
"{{ lookup('ansible.builtin.env', 'LOGURU_LEVEL') | default('INFO',
|
||||
true) }}"
|
||||
ports:
|
||||
- "{{ backend_port }}:8080"
|
||||
|
||||
@@ -165,13 +186,27 @@
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PORT') }}"
|
||||
EMAIL_SERVER_USER:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_USER') }}"
|
||||
FASTAPI_URL: "http://oasst-{{ stack_name }}-backend:8080"
|
||||
FASTAPI_URL: "{{ lookup('ansible.builtin.env', 'BACKEND_URL') }}"
|
||||
FASTAPI_KEY: "{{ web_api_key }}"
|
||||
NEXTAUTH_SECRET:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}"
|
||||
NEXTAUTH_URL:
|
||||
"{{ 'https://open-assistant.io/' if stack_name == 'production' else
|
||||
('https://web.' + stack_name + '.open-assistant.io/') }}"
|
||||
NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY') }}"
|
||||
CLOUDFLARE_CAPTCHA_SECRET_KEY:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY') }}"
|
||||
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA') }}"
|
||||
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN') }}"
|
||||
MAINTENANCE_MODE:
|
||||
"{{ lookup('ansible.builtin.env', 'MAINTENANCE_MODE') }}"
|
||||
ports:
|
||||
- "{{ website_port }}:3000"
|
||||
command: bash wait-for-postgres.sh node server.js
|
||||
|
||||
+3
-2
@@ -36,6 +36,7 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("user_stats", "streak_days")
|
||||
op.drop_column("user_stats", "streak_last_day_date")
|
||||
op.drop_column("user", "streak_days")
|
||||
op.drop_column("user", "streak_last_day_date")
|
||||
op.drop_column("user", "last_activity_date")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
"""add tos_acceptance_date to user
|
||||
|
||||
Revision ID: 55361f323d12
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 00:22:08.280251
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "55361f323d12"
|
||||
down_revision = "f60958968ff8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
|
||||
op.drop_column("user_stats", "streak_days")
|
||||
op.drop_column("user_stats", "streak_last_day_date")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
|
||||
)
|
||||
op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
|
||||
op.drop_column("user", "tos_acceptance_date")
|
||||
# ### end Alembic commands ###
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
"""add won_prompt_lottery_date to mts
|
||||
|
||||
Revision ID: f60958968ff8
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 10:10:38.301707
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f60958968ff8"
|
||||
down_revision = "7b8f0011e0b0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message_tree_state", sa.Column("won_prompt_lottery_date", sa.DateTime(timezone=True), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_tree_state", "won_prompt_lottery_date")
|
||||
# ### end Alembic commands ###
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
"""add skip bool & skip_reason to task
|
||||
|
||||
Revision ID: 9e7ec4a9e3f2
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 21:46:49.971052
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9e7ec4a9e3f2"
|
||||
down_revision = "55361f323d12"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("task", sa.Column("skipped", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("task", sa.Column("skip_reason", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("task", "skip_reason")
|
||||
op.drop_column("task", "skipped")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,59 @@
|
||||
"""add troll_stats
|
||||
|
||||
Revision ID: 4d7e0b0ebe84
|
||||
Revises: 9e7ec4a9e3f2
|
||||
Create Date: 2023-02-02 15:44:12.647260
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4d7e0b0ebe84"
|
||||
down_revision = "9e7ec4a9e3f2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"troll_stats",
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("troll_score", sa.Integer(), nullable=False),
|
||||
sa.Column("rank", sa.Integer(), nullable=True),
|
||||
sa.Column("red_flags", sa.Integer(), nullable=False),
|
||||
sa.Column("upvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("downvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("spam_prompts", sa.Integer(), nullable=False),
|
||||
sa.Column("quality", sa.Float(), nullable=True),
|
||||
sa.Column("humor", sa.Float(), nullable=True),
|
||||
sa.Column("toxicity", sa.Float(), nullable=True),
|
||||
sa.Column("violence", sa.Float(), nullable=True),
|
||||
sa.Column("helpfulness", sa.Float(), nullable=True),
|
||||
sa.Column("spam", sa.Integer(), nullable=False),
|
||||
sa.Column("lang_mismach", sa.Integer(), nullable=False),
|
||||
sa.Column("not_appropriate", sa.Integer(), nullable=False),
|
||||
sa.Column("pii", sa.Integer(), nullable=False),
|
||||
sa.Column("hate_speech", sa.Integer(), nullable=False),
|
||||
sa.Column("sexual_content", sa.Integer(), nullable=False),
|
||||
sa.Column("political_content", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("user_id", "time_frame"),
|
||||
)
|
||||
op.create_index("ix_troll_stats__timeframe__user_id", "troll_stats", ["time_frame", "user_id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_troll_stats__timeframe__user_id", table_name="troll_stats")
|
||||
op.drop_table("troll_stats")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Add Account table
|
||||
|
||||
Revision ID: 8c8241d1f973
|
||||
Revises: 4d7e0b0ebe84
|
||||
Create Date: 2023-01-30 15:10:58.776315
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8c8241d1f973"
|
||||
down_revision = "4d7e0b0ebe84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"account",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("provider", "account", ["provider_account_id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("provider", table_name="account")
|
||||
op.drop_table("account")
|
||||
# ### end Alembic commands ###
|
||||
+41
@@ -0,0 +1,41 @@
|
||||
"""Added new table for flagged messages
|
||||
|
||||
Revision ID: caee1e8ee0bc
|
||||
Revises: 8c8241d1f973
|
||||
Create Date: 2023-02-07 19:22:12.696257
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "caee1e8ee0bc"
|
||||
down_revision = "8c8241d1f973"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"flagged_message",
|
||||
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column("processed", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("message_id"),
|
||||
)
|
||||
op.create_index(op.f("ix_flagged_message_created_date"), "flagged_message", ["created_date"], unique=False)
|
||||
op.create_index(op.f("ix_flagged_message_processed"), "flagged_message", ["processed"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_flagged_message_processed"), table_name="flagged_message")
|
||||
op.drop_index(op.f("ix_flagged_message_created_date"), table_name="flagged_message")
|
||||
op.drop_table("flagged_message")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,218 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.models import Message, MessageTreeState
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_backend.utils import tree_export
|
||||
from sqlmodel import Session, not_
|
||||
|
||||
|
||||
def fetch_tree_ids(
|
||||
db: Session,
|
||||
state_filter: Optional[TreeState] = None,
|
||||
lang: Optional[str] = None,
|
||||
) -> list[tuple[UUID, TreeState]]:
|
||||
tree_qry = (
|
||||
db.query(MessageTreeState)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
)
|
||||
|
||||
if lang is not None:
|
||||
tree_qry = tree_qry.filter(Message.lang == lang)
|
||||
|
||||
if state_filter:
|
||||
tree_qry = tree_qry.filter(MessageTreeState.state == state_filter)
|
||||
|
||||
return [(tree.message_tree_id, tree.state) for tree in tree_qry]
|
||||
|
||||
|
||||
def fetch_tree_messages(
|
||||
db: Session,
|
||||
message_tree_id: Optional[UUID] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
deleted: bool = None,
|
||||
prompts_only: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
review_result: Optional[bool] = None,
|
||||
) -> List[Message]:
|
||||
qry = db.query(Message)
|
||||
|
||||
if message_tree_id:
|
||||
qry = qry.filter(Message.message_tree_id == message_tree_id)
|
||||
if user_id:
|
||||
qry = qry.filter(Message.user_id == user_id)
|
||||
if deleted is not None:
|
||||
qry = qry.filter(Message.deleted == deleted)
|
||||
if prompts_only:
|
||||
qry = qry.filter(Message.parent_id.is_(None))
|
||||
if lang:
|
||||
qry = qry.filter(Message.lang == lang)
|
||||
if review_result is False:
|
||||
qry = qry.filter(not_(Message.review_result), Message.review_count > 2)
|
||||
elif review_result is True:
|
||||
qry = qry.filter(Message.review_result)
|
||||
|
||||
return qry.all()
|
||||
|
||||
|
||||
def export_trees(
|
||||
db: Session,
|
||||
export_file: Optional[Path] = None,
|
||||
use_compression: bool = False,
|
||||
deleted: bool = False,
|
||||
user_id: Optional[UUID] = None,
|
||||
prompts_only: bool = False,
|
||||
state_filter: Optional[TreeState] = None,
|
||||
lang: Optional[str] = None,
|
||||
review_result: Optional[bool] = None,
|
||||
) -> None:
|
||||
trees_to_export: List[tree_export.ExportMessageTree] = []
|
||||
|
||||
if user_id or review_result is False:
|
||||
messages = fetch_tree_messages(
|
||||
db,
|
||||
user_id=user_id,
|
||||
deleted=deleted,
|
||||
prompts_only=prompts_only,
|
||||
lang=lang,
|
||||
review_result=review_result,
|
||||
)
|
||||
tree_export.write_messages_to_file(export_file, messages, use_compression)
|
||||
else:
|
||||
message_tree_ids = fetch_tree_ids(db, state_filter, lang=lang)
|
||||
|
||||
message_trees = [
|
||||
fetch_tree_messages(
|
||||
db,
|
||||
message_tree_id=tree_id,
|
||||
deleted=deleted,
|
||||
prompts_only=prompts_only,
|
||||
lang=None,
|
||||
review_result=review_result,
|
||||
)
|
||||
for (tree_id, _) in message_tree_ids
|
||||
]
|
||||
|
||||
# when exporting only-deleted we don't have a porper tree structure, export as list
|
||||
if deleted is True:
|
||||
messages = [m for t in message_trees for m in t]
|
||||
tree_export.write_messages_to_file(export_file, messages, use_compression)
|
||||
else:
|
||||
for (message_tree_id, message_tree_state), message_tree in zip(message_tree_ids, message_trees):
|
||||
t = tree_export.build_export_tree(message_tree_id, message_tree_state, message_tree)
|
||||
if prompts_only:
|
||||
t.prompt.replies = None
|
||||
trees_to_export.append(t)
|
||||
|
||||
tree_export.write_trees_to_file(export_file, trees_to_export, use_compression)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
if args.deleted_only:
|
||||
args.include_deleted = True
|
||||
|
||||
args.use_compression = args.export_file is not None and ".gz" in args.export_file
|
||||
|
||||
if args.state and args.user is not None:
|
||||
raise ValueError("Cannot use --state when specifying a user ID")
|
||||
|
||||
if args.export_file is None:
|
||||
logger.warning("No export file provided, output will be sent to STDOUT")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--export-file",
|
||||
type=str,
|
||||
help="Name of file to export trees to. If not provided, output will be sent to STDOUT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-deleted",
|
||||
action="store_true",
|
||||
help="Include deleted messages in export",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--deleted-only",
|
||||
action="store_true",
|
||||
help="Export only deleted messages (implies --include-deleted)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-spam",
|
||||
action="store_true",
|
||||
help="Export only messages with negative review result.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spam-only",
|
||||
action="store_true",
|
||||
help="Export only messages with negative review result (implies --include-spam).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user",
|
||||
type=str,
|
||||
help="Only export trees involving the user with the specified ID. Incompatible with --state.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state",
|
||||
type=str,
|
||||
help="all|prompt_lottery_waiting|growing|ready_for_export|aborted_low_grade|halted_by_moderator|backlog_ranking",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang",
|
||||
type=str,
|
||||
help="Filter message trees by language code (BCP 47)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompts-only",
|
||||
action="store_true",
|
||||
help="Export a list of initial prompt messages",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
validate_args(args)
|
||||
|
||||
state_filter: Optional[TreeState] = None
|
||||
if args.state is None:
|
||||
state_filter = TreeState.READY_FOR_EXPORT
|
||||
elif args.state != "all":
|
||||
state_filter = TreeState(args.state)
|
||||
|
||||
deleted: Optional[bool] = False
|
||||
if args.include_deleted:
|
||||
deleted = None
|
||||
if args.deleted_only:
|
||||
deleted = True
|
||||
|
||||
review_result: Optional[bool] = True
|
||||
if args.include_spam:
|
||||
review_result = None
|
||||
if args.spam_only:
|
||||
review_result = False
|
||||
|
||||
with Session(engine) as db:
|
||||
export_trees(
|
||||
db,
|
||||
Path(args.export_file) if args.export_file is not None else None,
|
||||
args.use_compression,
|
||||
deleted=deleted,
|
||||
user_id=UUID(args.user) if args.user is not None else None,
|
||||
prompts_only=args.prompts_only,
|
||||
state_filter=state_filter,
|
||||
lang=args.lang,
|
||||
review_result=review_result,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+11
-35
@@ -18,7 +18,8 @@ from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.models import message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, UserRepository
|
||||
from oasst_backend.task_repository import TaskRepository, delete_expired_tasks
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.user_repository import User
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
@@ -147,6 +148,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
ur = UserRepository(db=session, api_client=api_client)
|
||||
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
|
||||
pr = PromptRepository(
|
||||
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
@@ -317,6 +319,13 @@ def update_user_streak(session: Session) -> None:
|
||||
return
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * 60) # 1 hour
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def cronjob_delete_expired_tasks(session: Session) -> None:
|
||||
delete_expired_tasks(session)
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@@ -324,24 +333,6 @@ def get_openapi_schema():
|
||||
return json.dumps(app.openapi())
|
||||
|
||||
|
||||
def export_ready_trees(file: Optional[str] = None, use_compression: bool = False):
|
||||
try:
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(db, pr)
|
||||
|
||||
tm.export_all_ready_trees(file, use_compression=use_compression)
|
||||
except Exception:
|
||||
logger.exception("Error exporting trees.")
|
||||
|
||||
|
||||
def retry_scoring_failed_message_trees():
|
||||
try:
|
||||
logger.info("TreeManager.retry_scoring_failed_message_trees()")
|
||||
@@ -373,17 +364,6 @@ def main():
|
||||
)
|
||||
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
|
||||
parser.add_argument("--port", help="The port to run the server", default=8080)
|
||||
parser.add_argument(
|
||||
"--export",
|
||||
default=False,
|
||||
help="Export all trees which are ready for exporting.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-file",
|
||||
type=str,
|
||||
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retry-scoring",
|
||||
default=False,
|
||||
@@ -396,14 +376,10 @@ def main():
|
||||
if args.print_openapi_schema:
|
||||
print(get_openapi_schema())
|
||||
|
||||
if args.export:
|
||||
use_compression: bool = ".gz" in args.export_file
|
||||
export_ready_trees(file=args.export_file, use_compression=use_compression)
|
||||
|
||||
if args.retry_scoring:
|
||||
retry_scoring_failed_message_trees()
|
||||
|
||||
if not (args.export or args.print_openapi_schema or args.retry_scoring):
|
||||
if not (args.print_openapi_schema or args.retry_scoring):
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
@@ -162,3 +163,37 @@ async def purge_user_messages(
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
|
||||
class FlaggedMessageResponse(pydantic.BaseModel):
|
||||
message_id: UUID
|
||||
processed: bool
|
||||
created_date: Optional[datetime]
|
||||
|
||||
|
||||
@router.get("/flagged_messages", response_model=list[FlaggedMessageResponse])
|
||||
async def get_flagged_messages(
|
||||
max_count: Optional[int],
|
||||
session: deps.Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
pr = PromptRepository(session, api_client)
|
||||
flagged_messages = pr.fetch_flagged_messages(max_count=max_count)
|
||||
resp = [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages]
|
||||
return resp
|
||||
|
||||
|
||||
@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse)
|
||||
async def process_flagged_messages(
|
||||
message_id: UUID,
|
||||
session: deps.Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
pr = PromptRepository(session, api_client)
|
||||
flagged_msg = pr.process_flagged_message(message_id=message_id)
|
||||
resp = FlaggedMessageResponse(**flagged_msg.__dict__)
|
||||
return resp
|
||||
|
||||
@@ -10,6 +10,7 @@ from oasst_backend.api.v1 import (
|
||||
stats,
|
||||
tasks,
|
||||
text_labels,
|
||||
trollboards,
|
||||
users,
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
|
||||
@@ -59,6 +59,37 @@ def query_frontend_user(
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol.FrontEndUser)
|
||||
def create_frontend_user(
|
||||
*,
|
||||
create_user: protocol.CreateFrontendUserRequest,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
user = ur.lookup_client_user(create_user, create_missing=True)
|
||||
|
||||
def changed(a, b) -> bool:
|
||||
return a is not None and a != b
|
||||
|
||||
# only call update_user if something changed
|
||||
if (
|
||||
changed(create_user.enabled, user.enabled)
|
||||
or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
|
||||
or changed(create_user.notes, user.notes)
|
||||
or (create_user.tos_acceptance and user.tos_acceptance_date is None)
|
||||
):
|
||||
user = ur.update_user(
|
||||
user.id,
|
||||
enabled=create_user.enabled,
|
||||
show_on_leaderboard=create_user.show_on_leaderboard,
|
||||
tos_acceptance=create_user.tos_acceptance,
|
||||
notes=create_user.notes,
|
||||
)
|
||||
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
auth_method: str,
|
||||
|
||||
@@ -18,7 +18,7 @@ async def get_text_toxicity(
|
||||
|
||||
Args:
|
||||
msg (str): the message that we want to analyze.
|
||||
api_client (ApiClient, optional): authentification of the user of the request.
|
||||
api_client (ApiClient, optional): authentication of the user of the request.
|
||||
Defaults to Depends(deps.get_trusted_api_client).
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from oasst_backend import auth
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import Settings
|
||||
from oasst_backend.models import Account
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/discord")
|
||||
def login_discord(request: Request):
|
||||
redirect_uri = f"{get_callback_uri(request)}/discord"
|
||||
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={Settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
|
||||
raise HTTPException(status_code=302, headers={"location": auth_url})
|
||||
|
||||
|
||||
@router.get("/callback/discord", response_model=protocol_schema.Token)
|
||||
async def callback_discord(
|
||||
auth_code: str,
|
||||
request: Request,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
redirect_uri = f"{get_callback_uri(request)}/discord"
|
||||
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
# Exchange the auth code for a Discord access token
|
||||
async with session.post(
|
||||
"https://discord.com/api/oauth2/token",
|
||||
data={
|
||||
"client_id": Settings.AUTH_DISCORD_CLIENT_ID,
|
||||
"client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": "identify",
|
||||
},
|
||||
) as token_response:
|
||||
token_response_json = await token_response.json()
|
||||
access_token = token_response_json["access_token"]
|
||||
|
||||
# Retrieve user's Discord information using access token
|
||||
async with session.get(
|
||||
"https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
|
||||
) as user_response:
|
||||
user_response_json = await user_response.json()
|
||||
discord_id = user_response_json["id"]
|
||||
|
||||
account: Account = auth.get_account_from_discord_id(db, discord_id)
|
||||
|
||||
if not account:
|
||||
# Discord account is not linked to an OA account
|
||||
raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
|
||||
|
||||
# Discord account is valid and linked to an OA account -> create JWT
|
||||
access_token = auth.create_access_token(account)
|
||||
|
||||
return protocol_schema.Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
def get_callback_uri(request: Request):
|
||||
"""
|
||||
Gets the URI for the base callback endpoint with no provider name appended.
|
||||
"""
|
||||
# This seems ugly, not sure if there is a better way
|
||||
current_url = str(request.url)
|
||||
domain = current_url.split("/api/v1/")[0]
|
||||
redirect_uri = f"{domain}/api/v1/callback"
|
||||
return redirect_uri
|
||||
@@ -104,7 +104,8 @@ def tasks_acknowledge(
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
logger.info(f"Frontend ACK task_id={task_id}")
|
||||
logger.debug(f"{ack_request=}.")
|
||||
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
|
||||
except OasstError:
|
||||
@@ -131,7 +132,7 @@ def tasks_acknowledge_failure(
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
pr.skip_task(task_id=task_id, reason=nack_request.reason)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import TrollboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{time_frame}", response_model=TrollboardStats)
|
||||
def get_trollboard(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> TrollboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_trollboard(time_frame, limit=max_count)
|
||||
@@ -191,6 +191,7 @@ def update_user(
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
tos_acceptance: Optional[bool] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
@@ -198,7 +199,7 @@ def update_user(
|
||||
Update a user by global user ID. Only trusted clients can update users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -17,8 +17,15 @@ def prepare_message(m: Message) -> protocol.Message:
|
||||
created_date=m.created_date,
|
||||
emojis=m.emojis or {},
|
||||
user_emojis=m.user_emojis or [],
|
||||
user_is_author=m.user_is_author,
|
||||
review_result=m.review_result,
|
||||
review_count=m.review_count,
|
||||
ranking_count=m.ranking_count,
|
||||
deleted=m.deleted,
|
||||
synthetic=m.synthetic,
|
||||
model_name=m.model_name,
|
||||
message_tree_id=m.message_tree_id,
|
||||
rank=m.rank,
|
||||
)
|
||||
|
||||
|
||||
@@ -36,6 +43,7 @@ def prepare_conversation_message(message: Message) -> protocol.ConversationMessa
|
||||
is_assistant=(message.role == "assistant"),
|
||||
emojis=message.emojis or {},
|
||||
user_emojis=message.user_emojis or [],
|
||||
user_is_author=message.user_is_author,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from jose import jwt
|
||||
from oasst_backend.config import Settings
|
||||
from oasst_backend.models import Account
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""
|
||||
Create an encoded JSON Web Token (JWT) using the given data.
|
||||
"""
|
||||
|
||||
expires_delta = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
|
||||
"""
|
||||
Get the Open-Assistant Account associated with the given Discord ID.
|
||||
"""
|
||||
|
||||
account: Account = (
|
||||
db.query(Account)
|
||||
.filter(
|
||||
Account.provider == "discord",
|
||||
Account.provider_account_id == discord_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return account
|
||||
@@ -13,7 +13,10 @@ class TreeManagerConfiguration(BaseModel):
|
||||
No new initial prompt tasks are handed out to users if this
|
||||
number is reached."""
|
||||
|
||||
max_tree_depth: int = 6
|
||||
max_initial_prompt_review: int = 100
|
||||
"""Maximum number of initial prompts under review before no more initial prompt tasks will be handed out."""
|
||||
|
||||
max_tree_depth: int = 3
|
||||
"""Maximum depth of message tree."""
|
||||
|
||||
max_children_count: int = 3
|
||||
@@ -22,22 +25,39 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_prompter_replies: int = 1
|
||||
"""Number of prompter replies to collect per assistant reply."""
|
||||
|
||||
goal_tree_size: int = 15
|
||||
goal_tree_size: int = 12
|
||||
"""Total number of messages to gather per tree."""
|
||||
|
||||
random_goal_tree_size: bool = False
|
||||
"""If set to true goal tree sizes will be generated randomly within range [min_goal_tree_size, goal_tree_size]."""
|
||||
|
||||
min_goal_tree_size: int = 5
|
||||
"""Minimum tree size for random goal sizes."""
|
||||
|
||||
num_reviews_initial_prompt: int = 3
|
||||
"""Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state."""
|
||||
|
||||
num_reviews_reply: int = 3
|
||||
"""Number of peer review checks to collect per reply (other than initial_prompt)."""
|
||||
|
||||
p_full_labeling_review_prompt: float = 0.1
|
||||
auto_mod_enabled: bool = True
|
||||
"""Flag to enable/disable auto moderation."""
|
||||
|
||||
auto_mod_max_skip_reply: int = 25
|
||||
"""Automatically set tree state to `halted_by_moderator` when more than the specified number
|
||||
of users skip replying to a message. (auto moderation)"""
|
||||
|
||||
auto_mod_red_flags: int = 4
|
||||
"""Delete messages that receive more than this number of red flags if it is a reply or
|
||||
set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""
|
||||
|
||||
p_full_labeling_review_prompt: float = 1.0
|
||||
"""Probability of full text-labeling (instead of mandatory only) for initial prompts."""
|
||||
|
||||
p_full_labeling_review_reply_assistant: float = 0.1
|
||||
p_full_labeling_review_reply_assistant: float = 0.5
|
||||
"""Probability of full text-labeling (instead of mandatory only) for assistant replies."""
|
||||
|
||||
p_full_labeling_review_reply_prompter: float = 0.1
|
||||
p_full_labeling_review_reply_prompter: float = 0.25
|
||||
"""Probability of full text-labeling (instead of mandatory only) for prompter replies."""
|
||||
|
||||
acceptance_threshold_initial_prompt: float = 0.6
|
||||
@@ -55,7 +75,7 @@ class TreeManagerConfiguration(BaseModel):
|
||||
|
||||
min_active_rankings_per_lang: int = 0
|
||||
"""When the number of active ranking tasks is below this value when a tree enters a terminal
|
||||
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
|
||||
state an available trees in BACKLOG_RANKING will be activated (i.e. enters the RANKING state)."""
|
||||
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
@@ -112,13 +132,13 @@ class TreeManagerConfiguration(BaseModel):
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
lonely_children_count: int = 3
|
||||
lonely_children_count: int = 2
|
||||
"""Number of children below which parents are preferred during sampling for reply tasks."""
|
||||
|
||||
p_lonely_child_extension: float = 0.8
|
||||
p_lonely_child_extension: float = 0.75
|
||||
"""Probability to select a prompter message parent with less than lonely_children_count children."""
|
||||
|
||||
recent_tasks_span_sec: int = 3 * 60 # 3 min
|
||||
recent_tasks_span_sec: int = 5 * 60 # 5 min
|
||||
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
|
||||
|
||||
|
||||
@@ -135,6 +155,11 @@ class Settings(BaseSettings):
|
||||
AUTH_LENGTH: int = 32
|
||||
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
|
||||
AUTH_COOKIE_NAME: str = "next-auth.session-token"
|
||||
AUTH_ALGORITHM: str = "HS256"
|
||||
AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
AUTH_DISCORD_CLIENT_ID: str = ""
|
||||
AUTH_DISCORD_CLIENT_SECRET: str = ""
|
||||
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
@@ -144,6 +169,9 @@ class Settings(BaseSettings):
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
DATABASE_MAX_TX_RETRY_COUNT: int = 3
|
||||
|
||||
DATABASE_POOL_SIZE = 75
|
||||
DATABASE_MAX_OVERFLOW = 20
|
||||
|
||||
RATE_LIMIT: bool = True
|
||||
MESSAGE_SIZE_LIMIT: int = 2000
|
||||
REDIS_HOST: str = "localhost"
|
||||
@@ -154,10 +182,14 @@ class Settings(BaseSettings):
|
||||
Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
|
||||
)
|
||||
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
|
||||
DEBUG_ALLOW_SELF_RANKING: bool = False # allow users to rank their own messages
|
||||
DEBUG_ALLOW_DUPLICATE_TASKS: bool = False # offer users tasks to which they already responded
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
|
||||
True # TODO: set False after ToS acceptance UI was added to web-frontend
|
||||
)
|
||||
|
||||
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
|
||||
|
||||
@@ -214,6 +246,8 @@ class Settings(BaseSettings):
|
||||
RATE_LIMIT_TASK_API_TIMES: int = 10_000
|
||||
RATE_LIMIT_TASK_API_MINUTES: int = 1
|
||||
|
||||
TASK_VALIDITY_MINUTES: int = 60 * 24 * 2 # tasks expire after 2 days
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@@ -5,4 +5,10 @@ from sqlmodel import create_engine
|
||||
if settings.DATABASE_URI is None:
|
||||
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI, echo=settings.DEBUG_DATABASE_ECHO, isolation_level="REPEATABLE READ")
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URI,
|
||||
echo=settings.DEBUG_DATABASE_ECHO,
|
||||
isolation_level="REPEATABLE READ",
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .api_client import ApiClient
|
||||
from .flagged_message import FlaggedMessage
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .message import Message
|
||||
from .message_embedding import MessageEmbedding
|
||||
@@ -8,6 +9,7 @@ from .message_toxicity import MessageToxicity
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .troll_stats import TrollStats
|
||||
from .user import User
|
||||
from .user_stats import UserStats, UserStatsTimeFrame
|
||||
|
||||
@@ -26,4 +28,6 @@ __all__ = [
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
"MessageEmoji",
|
||||
"TrollStats",
|
||||
"FlaggedMessage",
|
||||
]
|
||||
|
||||
@@ -3,7 +3,7 @@ from uuid import UUID
|
||||
|
||||
from oasst_backend.models.payload_column_type import payload_type
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@payload_type
|
||||
@@ -117,8 +117,10 @@ class LabelConversationReplyPayload(TaskPayload):
|
||||
|
||||
message_id: UUID
|
||||
conversation: protocol_schema.Conversation
|
||||
reply: str # deprecated
|
||||
reply_message: Optional[protocol_schema.ConversationMessage]
|
||||
reply: Optional[str] = Field(None, deprecated=True, description="deprecated")
|
||||
reply_message: Optional[protocol_schema.ConversationMessage] = Field(
|
||||
None, deprecated=True, description="deprecated"
|
||||
)
|
||||
valid_labels: list[str]
|
||||
mandatory_labels: Optional[list[str]]
|
||||
mode: Optional[protocol_schema.LabelTaskMode]
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class FlaggedMessage(SQLModel, table=True):
|
||||
__tablename__ = "flagged_message"
|
||||
|
||||
message_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
|
||||
)
|
||||
)
|
||||
processed: bool = Field(nullable=False, index=True)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
|
||||
)
|
||||
)
|
||||
@@ -64,6 +64,7 @@ class Message(SQLModel, table=True):
|
||||
|
||||
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
|
||||
_user_is_author: Optional[bool] = PrivateAttr(default=None)
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
@@ -83,3 +84,7 @@ class Message(SQLModel, table=True):
|
||||
@property
|
||||
def user_emojis(self) -> str:
|
||||
return self._user_emojis
|
||||
|
||||
@property
|
||||
def user_is_author(self) -> str:
|
||||
return self._user_is_author
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -10,7 +12,7 @@ class State(str, Enum):
|
||||
"""States of the Open-Assistant message tree state machine."""
|
||||
|
||||
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
|
||||
"""In this state the message tree consists only of a single inital prompt root node.
|
||||
"""In this state the message tree consists only of a single initial prompt root node.
|
||||
Initial prompt labeling tasks will determine if the tree goes into `growing` or
|
||||
`aborted_low_grade` state."""
|
||||
|
||||
@@ -31,11 +33,11 @@ class State(str, Enum):
|
||||
compute the aggergated ranking scores that will appear in the dataset."""
|
||||
|
||||
READY_FOR_EXPORT = "ready_for_export"
|
||||
"""The Scoring algorithm computed rankings scores for all childern. The message tree can be
|
||||
"""The Scoring algorithm computed rankings scores for all children. The message tree can be
|
||||
exported as part of an Open-Assistant message tree dataset."""
|
||||
|
||||
SCORING_FAILED = "scoring_failed"
|
||||
"""An exception occured in the scoring algorithm."""
|
||||
"""An exception occurred in the scoring algorithm."""
|
||||
|
||||
ABORTED_LOW_GRADE = "aborted_low_grade"
|
||||
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
|
||||
@@ -46,6 +48,9 @@ class State(str, Enum):
|
||||
BACKLOG_RANKING = "backlog_ranking"
|
||||
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
|
||||
|
||||
PROMPT_LOTTERY_WAITING = "prompt_lottery_waiting"
|
||||
"""Initial prompt has passed spam check, waiting to be drawn to grow."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
State.INITIAL_PROMPT_REVIEW,
|
||||
@@ -63,6 +68,7 @@ TERMINAL_STATES = (
|
||||
State.SCORING_FAILED,
|
||||
State.HALTED_BY_MODERATOR,
|
||||
State.BACKLOG_RANKING,
|
||||
State.PROMPT_LOTTERY_WAITING,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,3 +84,4 @@ class MessageTreeState(SQLModel, table=True):
|
||||
state: str = Field(nullable=False, max_length=128, index=True)
|
||||
active: bool = Field(nullable=False, index=True)
|
||||
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
@@ -31,6 +31,8 @@ class Task(SQLModel, table=True):
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
ack: Optional[bool] = None
|
||||
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
skipped: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
skip_reason: str = Field(nullable=True, max_length=512)
|
||||
frontend_message_id: Optional[str] = None
|
||||
message_tree_id: Optional[UUID] = None
|
||||
parent_message_id: Optional[UUID] = None
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class TrollStats(SQLModel, table=True):
|
||||
__tablename__ = "troll_stats"
|
||||
__table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)
|
||||
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
|
||||
)
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
troll_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
rank: int = Field(nullable=True)
|
||||
|
||||
red_flags: int = 0 # num reported messages of user
|
||||
upvotes: int = 0 # num up-voted messages of user
|
||||
downvotes: int = 0 # num down-voted messages of user
|
||||
|
||||
spam_prompts: int = 0
|
||||
|
||||
quality: float = Field(nullable=True)
|
||||
humor: float = Field(nullable=True)
|
||||
toxicity: float = Field(nullable=True)
|
||||
violence: float = Field(nullable=True)
|
||||
helpfulness: float = Field(nullable=True)
|
||||
|
||||
spam: int = 0
|
||||
lang_mismach: int = 0
|
||||
not_appropriate: int = 0
|
||||
pii: int = 0
|
||||
hate_speech: int = 0
|
||||
sexual_content: int = 0
|
||||
political_content: int = 0
|
||||
|
||||
def compute_troll_score(self) -> int:
|
||||
return (
|
||||
self.red_flags * 3
|
||||
- self.upvotes
|
||||
+ self.downvotes
|
||||
+ self.spam_prompts
|
||||
+ self.lang_mismach
|
||||
+ self.not_appropriate
|
||||
+ self.pii
|
||||
+ self.hate_speech
|
||||
+ self.sexual_content
|
||||
+ self.political_content
|
||||
)
|
||||
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
# terms of service acceptance date
|
||||
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
def to_protocol_frontend_user(self):
|
||||
return protocol.FrontEndUser(
|
||||
user_id=self.id,
|
||||
@@ -55,4 +58,19 @@ class User(SQLModel, table=True):
|
||||
streak_days=self.streak_days,
|
||||
streak_last_day_date=self.streak_last_day_date,
|
||||
last_activity_date=self.last_activity_date,
|
||||
tos_acceptance_date=self.tos_acceptance_date,
|
||||
)
|
||||
|
||||
|
||||
class Account(SQLModel, table=True):
|
||||
__tablename__ = "account"
|
||||
__table_args__ = (Index("provider", "provider_account_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
user_id: UUID = Field(foreign_key="user.id")
|
||||
provider: str = Field(nullable=False, max_length=128, default="email") # discord or email
|
||||
provider_account_id: str = Field(nullable=False, max_length=128)
|
||||
|
||||
@@ -7,12 +7,14 @@ from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import FrontendUserId
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import (
|
||||
ApiClient,
|
||||
FlaggedMessage,
|
||||
Message,
|
||||
MessageEmbedding,
|
||||
MessageEmoji,
|
||||
@@ -35,7 +37,21 @@ from oasst_shared.utils import unaware_to_utc, utcnow
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
_task_type_and_reaction = (
|
||||
(
|
||||
(db_payload.PrompterReplyPayload, db_payload.AssistantReplyPayload),
|
||||
protocol_schema.EmojiCode.skip_reply,
|
||||
),
|
||||
(
|
||||
(db_payload.LabelInitialPromptPayload, db_payload.LabelConversationReplyPayload),
|
||||
protocol_schema.EmojiCode.skip_labeling,
|
||||
),
|
||||
(
|
||||
(db_payload.RankInitialPromptsPayload, db_payload.RankConversationRepliesPayload),
|
||||
protocol_schema.EmojiCode.skip_ranking,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
@@ -77,7 +93,14 @@ class PromptRepository:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
if self.user.deleted or not self.user.enabled:
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
|
||||
|
||||
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
|
||||
raise OasstError(
|
||||
"User has not accepted terms of service.",
|
||||
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
|
||||
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
|
||||
)
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
@@ -90,7 +113,7 @@ class PromptRepository:
|
||||
raise OasstError(
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.",
|
||||
OasstErrorCode.MESSAGE_NOT_FOUND,
|
||||
HTTP_404_NOT_FOUND,
|
||||
HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
return message
|
||||
|
||||
@@ -134,12 +157,15 @@ class PromptRepository:
|
||||
review_result=review_result,
|
||||
)
|
||||
self.db.add(message)
|
||||
|
||||
# self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def _validate_task(
|
||||
self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None
|
||||
self,
|
||||
task: Task,
|
||||
*,
|
||||
task_id: Optional[UUID] = None,
|
||||
frontend_message_id: Optional[str] = None,
|
||||
check_ack: bool = True,
|
||||
) -> Task:
|
||||
if task is None:
|
||||
if task_id:
|
||||
@@ -150,7 +176,7 @@ class PromptRepository:
|
||||
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if not task.ack:
|
||||
if check_ack and not task.ack:
|
||||
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
|
||||
if task.done:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
@@ -262,6 +288,10 @@ class PromptRepository:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
|
||||
logger.debug(
|
||||
f"Inserted message id={user_message.id}, tree={user_message.message_tree_id}, user_id={user_message.user_id}, "
|
||||
f"text[:100]='{user_message.text[:100]}', role='{user_message.role}', lang='{user_message.lang}'"
|
||||
)
|
||||
return user_message
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
@@ -456,7 +486,7 @@ class PromptRepository:
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
logger.debug(f"text_labels relpy: {valid_labels=}, {mandatory_labels=}")
|
||||
logger.debug(f"text_labels reply: {valid_labels=}, {mandatory_labels=}")
|
||||
|
||||
if valid_labels:
|
||||
if not all([label in valid_labels for label in text_labels.labels.keys()]):
|
||||
@@ -628,12 +658,6 @@ class PromptRepository:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return self._add_user_emojis_all(qry)
|
||||
|
||||
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
|
||||
qry = self.db.query(MessageTreeState).filter(
|
||||
MessageTreeState.state == message_tree_state.State.READY_FOR_EXPORT
|
||||
)
|
||||
return qry.all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
"""
|
||||
Fetch a conversation with multiple possible replies to it.
|
||||
@@ -675,7 +699,7 @@ class PromptRepository:
|
||||
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
|
||||
return message
|
||||
|
||||
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
|
||||
@@ -848,6 +872,7 @@ class PromptRepository:
|
||||
user_emojis = x["user_emojis"]
|
||||
if user_emojis:
|
||||
m._user_emojis = user_emojis.split(",")
|
||||
m._user_is_author = self.user_id and self.user_id == m.user_id
|
||||
messages.append(m)
|
||||
return messages
|
||||
|
||||
@@ -874,7 +899,7 @@ class PromptRepository:
|
||||
|
||||
if api_client_id != self.api_client.id:
|
||||
# Unprivileged api client asks for foreign messages
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
|
||||
|
||||
qry = self.db.query(Message)
|
||||
if user_id:
|
||||
@@ -995,7 +1020,31 @@ WHERE message.id = cc.id;
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
|
||||
@managed_tx_method()
|
||||
def skip_task(self, task_id: UUID, reason: str):
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
task = self.task_repository.fetch_task_by_id(task_id)
|
||||
self._validate_task(task, check_ack=False)
|
||||
|
||||
if not task.collective:
|
||||
task.skipped = True
|
||||
task.skip_reason = reason
|
||||
self.db.add(task)
|
||||
|
||||
def handle_cancel_emoji(task_payload: db_payload.TaskPayload) -> Message | None:
|
||||
for types, emoji in _task_type_and_reaction:
|
||||
for t in types:
|
||||
if isinstance(task_payload, t):
|
||||
return self.handle_message_emoji(task.parent_message_id, protocol_schema.EmojiOp.add, emoji)
|
||||
return None
|
||||
|
||||
task_payload: db_payload.TaskPayload = task.payload.payload
|
||||
handle_cancel_emoji(task_payload)
|
||||
|
||||
def handle_message_emoji(
|
||||
self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema.EmojiCode
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
message = self.fetch_message(message_id)
|
||||
@@ -1038,6 +1087,22 @@ WHERE message.id = cc.id;
|
||||
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
|
||||
)
|
||||
|
||||
if message.user_id == self.user_id and emoji in (
|
||||
protocol_schema.EmojiCode.thumbs_up,
|
||||
protocol_schema.EmojiCode.thumbs_down,
|
||||
):
|
||||
logger.debug(f"Ignoring add emoji op for user's own message ({emoji=})")
|
||||
return message
|
||||
|
||||
# Add to flagged_message table if the red flag emoji is applied
|
||||
if emoji == protocol_schema.EmojiCode.red_flag:
|
||||
flagged_message = FlaggedMessage(message_id=message_id, processed=False, created_date=utcnow())
|
||||
insert_stmt = pg.insert(FlaggedMessage).values(**flagged_message.dict())
|
||||
upsert_stmt = insert_stmt.on_conflict_do_update(
|
||||
constraint="flagged_message_pkey", set_=flagged_message.dict()
|
||||
)
|
||||
self.db.execute(upsert_stmt)
|
||||
|
||||
# insert emoji record & increment count
|
||||
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
|
||||
self.db.add(message_emoji)
|
||||
@@ -1073,3 +1138,23 @@ WHERE message.id = cc.id;
|
||||
self.db.add(message)
|
||||
self.db.flush()
|
||||
return message
|
||||
|
||||
def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessage]:
|
||||
qry = self.db.query(FlaggedMessage)
|
||||
if max_count is not None:
|
||||
qry = qry.limit(max_count)
|
||||
|
||||
return qry.all()
|
||||
|
||||
def process_flagged_message(self, message_id: UUID) -> FlaggedMessage:
|
||||
|
||||
message = self.db.query(FlaggedMessage).get(message_id)
|
||||
|
||||
if not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
|
||||
|
||||
message.processed = True
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, or_
|
||||
from oasst_shared.utils import utcnow
|
||||
from sqlmodel import Session, delete, false, func, or_
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -24,6 +26,13 @@ def validate_frontend_message_id(message_id: str) -> None:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
|
||||
def delete_expired_tasks(session: Session) -> int:
|
||||
stm = delete(Task).where(Task.expiry_date < utcnow(), Task.done == false())
|
||||
result = session.exec(stm)
|
||||
logger.info(f"Deleted {result.rowcount} expired tasks.")
|
||||
return result.rowcount
|
||||
|
||||
|
||||
class TaskRepository:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -100,8 +109,6 @@ class TaskRepository:
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
reply_message=task.reply_message,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
@@ -112,8 +119,6 @@ class TaskRepository:
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
reply_message=task.reply_message,
|
||||
valid_labels=task.valid_labels,
|
||||
mandatory_labels=task.mandatory_labels,
|
||||
mode=task.mode,
|
||||
@@ -122,12 +127,18 @@ class TaskRepository:
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
if not collective and settings.TASK_VALIDITY_MINUTES > 0:
|
||||
expiry_date = utcnow() + timedelta(minutes=settings.TASK_VALIDITY_MINUTES)
|
||||
else:
|
||||
expiry_date = None
|
||||
|
||||
task_model = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
expiry_date=expiry_date,
|
||||
)
|
||||
assert task_model.id == task.id
|
||||
return task_model
|
||||
@@ -166,26 +177,11 @@ class TaskRepository:
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
raise OasstError("Already closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def insert_task(
|
||||
self,
|
||||
@@ -194,6 +190,7 @@ class TaskRepository:
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
expiry_date: datetime = None,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
@@ -205,6 +202,7 @@ class TaskRepository:
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
expiry_date=expiry_date,
|
||||
)
|
||||
logger.debug(f"inserting {task=}")
|
||||
self.db.add(task)
|
||||
@@ -224,14 +222,19 @@ class TaskRepository:
|
||||
return task
|
||||
|
||||
def fetch_recent_reply_tasks(
|
||||
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100
|
||||
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100
|
||||
) -> list[Task]:
|
||||
qry = self.db.query(Task).filter(
|
||||
func.age(Task.created_date) < max_age,
|
||||
func.age(func.current_timestamp(), Task.created_date) < max_age,
|
||||
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
|
||||
)
|
||||
if done is not None:
|
||||
qry = qry.filter(Task.done == done)
|
||||
if skipped is not None:
|
||||
qry = qry.filter(Task.skipped == skipped)
|
||||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
return qry.all()
|
||||
|
||||
def delete_expired(self) -> int:
|
||||
return delete_expired_tasks(self.db)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
@@ -9,23 +7,34 @@ from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import (
|
||||
prepare_conversation,
|
||||
prepare_conversation_message,
|
||||
prepare_conversation_message_list,
|
||||
)
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
|
||||
from oasst_backend.models import (
|
||||
Message,
|
||||
MessageEmoji,
|
||||
MessageReaction,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
User,
|
||||
message_tree_state,
|
||||
)
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils import tree_export
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.database_utils import (
|
||||
CommitMode,
|
||||
async_managed_tx_method,
|
||||
managed_tx_function,
|
||||
managed_tx_method,
|
||||
)
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, not_, or_, text, update
|
||||
from oasst_shared.utils import utcnow
|
||||
from sqlalchemy.sql.functions import coalesce
|
||||
from sqlmodel import Session, and_, func, not_, or_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -43,6 +52,19 @@ class TaskRole(Enum):
|
||||
ASSISTANT = 2
|
||||
|
||||
|
||||
class TreeStateStats(pydantic.BaseModel):
|
||||
initial_prompt_review: int
|
||||
growing: int
|
||||
ranking: int
|
||||
ready_for_scoring: int
|
||||
scoring_failed: int
|
||||
ready_for_export: int
|
||||
aborted_low_grade: int
|
||||
halted_by_moderator: int
|
||||
backlog_ranking: int
|
||||
prompt_lottery_waiting: int
|
||||
|
||||
|
||||
class ActiveTreeSizeRow(pydantic.BaseModel):
|
||||
message_tree_id: UUID
|
||||
goal_tree_size: int
|
||||
@@ -157,7 +179,7 @@ class TreeManager:
|
||||
|
||||
def _determine_task_availability_internal(
|
||||
self,
|
||||
num_active_trees: int,
|
||||
num_missing_prompts: int,
|
||||
extendible_parents: list[ExtendibleParentRow],
|
||||
prompts_need_review: list[Message],
|
||||
replies_need_review: list[Message],
|
||||
@@ -165,7 +187,6 @@ class TreeManager:
|
||||
) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}
|
||||
|
||||
num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
|
||||
@@ -198,6 +219,117 @@ class TreeManager:
|
||||
|
||||
return task_count_by_type
|
||||
|
||||
def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
|
||||
# Under high load the DB runs into deadlocks when many trees are released
|
||||
# simultaneously (happens whens the max_active_trees setting is increased).
|
||||
# To reduce the chance of write conflicts during updates of rows in the
|
||||
# message_tree_state table we limit the number of trees that are activated
|
||||
# per _prompt_lottery() call to max_activate.
|
||||
activated = 0
|
||||
|
||||
while True:
|
||||
|
||||
stats = self.tree_counts_by_state_stats(lang=lang, only_active=True)
|
||||
|
||||
remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
|
||||
num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
|
||||
logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
|
||||
|
||||
if num_missing_growing == 0 or activated >= max_activate:
|
||||
return num_missing_growing + remaining_prompt_review
|
||||
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
def activate_one(db: Session) -> int:
|
||||
# select among distinct users
|
||||
authors_qry = (
|
||||
db.query(Message.user_id)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.distinct(Message.user_id)
|
||||
)
|
||||
|
||||
author_ids = authors_qry.all()
|
||||
if len(author_ids) == 0:
|
||||
logger.info(
|
||||
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
|
||||
)
|
||||
return False
|
||||
|
||||
# first select an authour
|
||||
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
|
||||
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
|
||||
|
||||
# select random prompt of author
|
||||
qry = (
|
||||
db.query(MessageTreeState, Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.user_id == prompt_author_id,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.limit(100)
|
||||
)
|
||||
|
||||
prompt_candidates = qry.all()
|
||||
if len(prompt_candidates) == 0:
|
||||
logger.warning("No prompt candidates of selected author found.")
|
||||
return False
|
||||
|
||||
winner_prompt = random.choice(prompt_candidates)
|
||||
message: Message = winner_prompt.Message
|
||||
logger.info(f"Prompt lottery winner: {message.id=}")
|
||||
|
||||
mts: MessageTreeState = winner_prompt.MessageTreeState
|
||||
mts.state = message_tree_state.State.GROWING
|
||||
mts.active = True
|
||||
db.add(mts)
|
||||
|
||||
if mts.won_prompt_lottery_date is None:
|
||||
mts.won_prompt_lottery_date = utcnow()
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
return True
|
||||
|
||||
if not activate_one():
|
||||
return num_missing_growing + remaining_prompt_review
|
||||
|
||||
activated += 1
|
||||
|
||||
def _auto_moderation(self, lang: str) -> None:
|
||||
if not self.cfg.auto_mod_enabled:
|
||||
return
|
||||
|
||||
bad_messages = self.query_moderation_bad_messages(lang=lang)
|
||||
for m in bad_messages:
|
||||
num_red_flag = m.emojis.get(protocol_schema.EmojiCode.red_flag)
|
||||
|
||||
if num_red_flag is not None and num_red_flag >= self.cfg.auto_mod_red_flags:
|
||||
if m.parent_id is None:
|
||||
logger.warning(
|
||||
f"[AUTO MOD] Halting tree {m.message_tree_id}, initial prompt got too many red flags ({m.emojis})."
|
||||
)
|
||||
self.enter_low_grade_state(m.message_tree_id)
|
||||
else:
|
||||
logger.warning(f"[AUTO MOD] Deleting message {m.id=}, it received too many red flags ({m.emojis}).")
|
||||
self.pr.mark_messages_deleted(m.id, recursive=True)
|
||||
|
||||
num_skip_reply = m.emojis.get(protocol_schema.EmojiCode.skip_reply)
|
||||
if num_skip_reply is not None and num_skip_reply >= self.cfg.auto_mod_max_skip_reply:
|
||||
logger.warning(
|
||||
f"[AUTO MOD] Halting tree {m.message_tree_id} due to high skip-reply count of message {m.id=} ({m.emojis})."
|
||||
)
|
||||
self.halt_tree(m.id, halt=True)
|
||||
|
||||
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
@@ -205,14 +337,15 @@ class TreeManager:
|
||||
lang = "en"
|
||||
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
self._auto_moderation(lang=lang)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=1)
|
||||
extendible_parents, _ = self.query_extendible_parents(lang=lang)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
|
||||
|
||||
return self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
num_missing_prompts=num_missing_prompts,
|
||||
extendible_parents=extendible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
@@ -242,7 +375,9 @@ class TreeManager:
|
||||
lang = "en"
|
||||
logger.warning("Task request without lang tag received, assuming 'en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
self._auto_moderation(lang=lang)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)
|
||||
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
|
||||
@@ -260,7 +395,7 @@ class TreeManager:
|
||||
num_ranking_tasks=len(incomplete_rankings),
|
||||
num_replies_need_review=len(replies_need_review),
|
||||
num_prompts_need_review=len(prompts_need_review),
|
||||
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
||||
num_missing_prompts=num_missing_prompts,
|
||||
num_missing_replies=num_missing_replies,
|
||||
)
|
||||
|
||||
@@ -272,7 +407,7 @@ class TreeManager:
|
||||
)
|
||||
else:
|
||||
task_count_by_type = self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
num_missing_prompts=num_missing_prompts,
|
||||
extendible_parents=extendible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
@@ -386,7 +521,6 @@ class TreeManager:
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
reply_message=prepare_conversation_message(message),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
@@ -412,7 +546,6 @@ class TreeManager:
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
reply_message=prepare_conversation_message(message),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
@@ -425,11 +558,6 @@ class TreeManager:
|
||||
|
||||
case TaskType.REPLY:
|
||||
|
||||
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
|
||||
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False
|
||||
)
|
||||
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
|
||||
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
@@ -440,24 +568,17 @@ class TreeManager:
|
||||
random_parent: ExtendibleParentRow = None
|
||||
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
|
||||
# check if we have extendible prompter parents with a small number of replies
|
||||
|
||||
lonely_children_parents = [
|
||||
p
|
||||
for p in extendible_parents
|
||||
if 0 < p.active_children_count < self.cfg.lonely_children_count
|
||||
and p.parent_role == "prompter"
|
||||
and p.parent_id not in recent_reply_task_parents
|
||||
]
|
||||
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
|
||||
random_parent = random.choice(lonely_children_parents)
|
||||
|
||||
if random_parent is None:
|
||||
# try to exclude parents for which tasks were recently handed out
|
||||
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
|
||||
if len(fresh_parents) > 0:
|
||||
random_parent = random.choice(fresh_parents)
|
||||
else:
|
||||
random_parent = random.choice(extendible_parents)
|
||||
random_parent = random.choice(extendible_parents)
|
||||
|
||||
# fetch random conversation to extend
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
@@ -479,6 +600,7 @@ class TreeManager:
|
||||
case TaskType.LABEL_PROMPT:
|
||||
assert len(prompts_need_review) > 0
|
||||
message = random.choice(prompts_need_review)
|
||||
message = self.pr.fetch_message(message.id) # re-fetch message including emojis
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
label_disposition = protocol_schema.LabelTaskDisposition.quality
|
||||
@@ -495,6 +617,7 @@ class TreeManager:
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.text,
|
||||
conversation=prepare_conversation([message]),
|
||||
valid_labels=list(map(lambda x: x.value, valid_labels)),
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
|
||||
mode=label_mode,
|
||||
@@ -519,7 +642,8 @@ class TreeManager:
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
logger.info(f"Generated task (type={task.type}, id={task.id})")
|
||||
logger.debug(f"Generated {task=}.")
|
||||
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
@@ -530,8 +654,9 @@ class TreeManager:
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
f"Frontend reports text reply to message_id={interaction.message_id} by user={interaction.user}."
|
||||
)
|
||||
logger.debug(f"with {interaction.text=}")
|
||||
# here we store the text reply in the database
|
||||
message = pr.store_text_reply(
|
||||
text=interaction.text,
|
||||
@@ -578,23 +703,26 @@ class TreeManager:
|
||||
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
f"Frontend reports rating of message_id={interaction.message_id} by user={interaction.user}."
|
||||
)
|
||||
logger.debug(f"with {interaction.rating=}")
|
||||
|
||||
pr.store_rating(interaction)
|
||||
|
||||
case protocol_schema.MessageRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
f"Frontend reports ranking of message_id={interaction.message_id} by user={interaction.user}."
|
||||
)
|
||||
logger.debug(f"with {interaction.ranking=}")
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
|
||||
f"Frontend reports labels of message_id={interaction.message_id} by user={interaction.user}."
|
||||
)
|
||||
logger.debug(f"with {interaction.labels=}")
|
||||
|
||||
_, task, msg = pr.store_text_labels(interaction)
|
||||
|
||||
@@ -615,7 +743,7 @@ class TreeManager:
|
||||
)
|
||||
else:
|
||||
self.enter_low_grade_state(msg.message_tree_id)
|
||||
self.check_condition_for_growing_state(msg.message_tree_id)
|
||||
self.check_condition_for_prompt_lottery(msg.message_tree_id)
|
||||
elif msg.review_count >= self.cfg.num_reviews_reply:
|
||||
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
|
||||
msg.review_result = True
|
||||
@@ -649,10 +777,12 @@ class TreeManager:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
|
||||
if self.cfg.min_active_rankings_per_lang > 0:
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang, user_filter=False)
|
||||
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
else:
|
||||
if mts.state == message_tree_state.State.GROWING and mts.won_prompt_lottery_date is None:
|
||||
mts.won_prompt_lottery_date = utcnow()
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
|
||||
@@ -660,8 +790,8 @@ class TreeManager:
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
|
||||
def check_condition_for_prompt_lottery(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_prompt_lottery({message_tree_id=})")
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW:
|
||||
@@ -674,7 +804,7 @@ class TreeManager:
|
||||
logger.debug(f"False {initial_prompt.review_result=}")
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
self._enter_state(mts, message_tree_state.State.PROMPT_LOTTERY_WAITING)
|
||||
return True
|
||||
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
|
||||
@@ -746,7 +876,7 @@ class TreeManager:
|
||||
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
|
||||
continue
|
||||
|
||||
# keep only elements in commond set
|
||||
# keep only elements in common set
|
||||
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
|
||||
assert all(len(x) == len(common_set) for x in ordered_ids_list)
|
||||
|
||||
@@ -824,6 +954,14 @@ class TreeManager:
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.outerjoin(
|
||||
MessageEmoji,
|
||||
and_(
|
||||
Message.id == MessageEmoji.message_id,
|
||||
MessageEmoji.user_id == self.pr.user_id,
|
||||
MessageEmoji.emoji == protocol_schema.EmojiCode.skip_labeling,
|
||||
),
|
||||
)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == state,
|
||||
@@ -831,6 +969,7 @@ class TreeManager:
|
||||
not_(Message.deleted),
|
||||
Message.review_count < required_reviews,
|
||||
Message.lang == lang,
|
||||
MessageEmoji.message_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -883,14 +1022,23 @@ SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) chi
|
||||
mts.message_tree_id
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
INNER JOIN message p ON m.parent_id = p.id
|
||||
LEFT JOIN message_emoji me on
|
||||
(m.parent_id = me.message_id
|
||||
AND :skip_user_id IS NOT NULL
|
||||
AND me.user_id = :skip_user_id
|
||||
AND me.emoji = :skip_ranking)
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :ranking_state -- message tree must be in ranking state
|
||||
AND m.review_result -- must be reviewed
|
||||
AND m.lang = :lang -- matches lang
|
||||
AND p.lang = :lang -- parent lang matches
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
AND me.message_id IS NULL -- no skip ranking emoji for user
|
||||
GROUP BY m.parent_id, m.role, mts.message_tree_id
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
HAVING COUNT(m.id) > 1 -- more than one child
|
||||
AND MIN(m.ranking_count) < :num_required_rankings -- not complete
|
||||
AND COUNT(m.id) FILTER (WHERE m.user_id = :rank_user_id) = 0 -- no self-ranking
|
||||
"""
|
||||
|
||||
_sql_find_incomplete_rankings_ex = f"""
|
||||
@@ -900,29 +1048,54 @@ SELECT ir.* FROM incomplete_rankings ir
|
||||
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
|
||||
ir.message_tree_id
|
||||
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
HAVING COUNT(mr.message_id) FILTER (WHERE mr.user_id = :dupe_user_id) = 0
|
||||
"""
|
||||
|
||||
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
|
||||
"""Query parents which have childern that need further rankings"""
|
||||
def query_incomplete_rankings(self, lang: str, user_filter: bool = True) -> list[IncompleteRankingsRow]:
|
||||
"""Query parents which have children that need further rankings"""
|
||||
|
||||
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
|
||||
dupe_user_id = None
|
||||
skip_user_id = None
|
||||
rank_user_id = None
|
||||
if user_filter:
|
||||
if not settings.DEBUG_ALLOW_DUPLICATE_TASKS:
|
||||
dupe_user_id = self.pr.user_id
|
||||
if not settings.DEBUG_ALLOW_SELF_RANKING:
|
||||
rank_user_id = self.pr.user_id
|
||||
skip_user_id = self.pr.user_id
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_incomplete_rankings_ex),
|
||||
{
|
||||
"num_required_rankings": self.cfg.num_required_rankings,
|
||||
"ranking_state": message_tree_state.State.RANKING,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
"dupe_user_id": dupe_user_id,
|
||||
"skip_user_id": skip_user_id,
|
||||
"rank_user_id": rank_user_id,
|
||||
"ranking_state": message_tree_state.State.RANKING,
|
||||
"skip_ranking": protocol_schema.EmojiCode.skip_ranking,
|
||||
},
|
||||
)
|
||||
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
|
||||
|
||||
_sql_find_extendible_parents = """
|
||||
-- find all extendible parent nodes
|
||||
WITH recent_reply_tasks (parent_message_id) AS (
|
||||
-- recent incomplete tasks to exclude
|
||||
SELECT parent_message_id FROM task
|
||||
WHERE not done
|
||||
AND not skipped
|
||||
AND created_date > (CURRENT_TIMESTAMP - :recent_tasks_interval)
|
||||
AND (payload_type = 'AssistantReplyPayload' OR payload_type = 'PrompterReplyPayload')
|
||||
)
|
||||
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
LEFT JOIN message_emoji me ON
|
||||
(m.id = me.message_id
|
||||
AND :skip_user_id IS NOT NULL
|
||||
AND me.user_id = :skip_user_id
|
||||
AND me.emoji = :skip_reply)
|
||||
LEFT JOIN recent_reply_tasks rrt ON m.id = rrt.parent_message_id -- recent tasks
|
||||
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :growing_state -- message tree must be growing
|
||||
@@ -930,6 +1103,8 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND m.lang = :lang -- parent matches lang
|
||||
AND me.message_id IS NULL -- no skip reply emoji for user
|
||||
AND rrt.parent_message_id IS NULL -- no recent reply task found
|
||||
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
|
||||
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
@@ -950,6 +1125,9 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"num_prompter_replies": self.cfg.num_prompter_replies,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
"skip_user_id": self.pr.user_id,
|
||||
"skip_reply": protocol_schema.EmojiCode.skip_reply,
|
||||
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -988,6 +1166,9 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
"num_prompter_replies": self.cfg.num_prompter_replies,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
"skip_user_id": self.pr.user_id,
|
||||
"skip_reply": protocol_schema.EmojiCode.skip_reply,
|
||||
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
|
||||
},
|
||||
)
|
||||
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
|
||||
@@ -1079,7 +1260,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def ensure_tree_states(self) -> None:
|
||||
"""Add message tree state rows for all root nodes (inital prompt messages)."""
|
||||
"""Add message tree state rows for all root nodes (initial prompt messages)."""
|
||||
|
||||
missing_tree_ids = self.query_misssing_tree_states()
|
||||
for id in missing_tree_ids:
|
||||
@@ -1101,7 +1282,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
f"Checking state of {len(prompt_review_trees)} active message trees in 'initial_prompt_review' state."
|
||||
)
|
||||
for t in prompt_review_trees:
|
||||
self.check_condition_for_growing_state(t.message_tree_id)
|
||||
self.check_condition_for_prompt_lottery(t.message_tree_id)
|
||||
|
||||
growing_trees: list[MessageTreeState] = (
|
||||
self.db.query(MessageTreeState)
|
||||
@@ -1129,8 +1310,23 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
for t in ranking_trees:
|
||||
self.check_condition_for_scoring_state(t.message_tree_id)
|
||||
|
||||
def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int:
|
||||
"""Count all active trees (optionally exclude those in ranking state)."""
|
||||
def query_num_growing_trees(self, lang: str) -> int:
|
||||
"""Count all active trees in growing state."""
|
||||
query = (
|
||||
self.db.query(func.count(MessageTreeState.message_tree_id))
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
return query.scalar()
|
||||
|
||||
def query_num_active_trees(
|
||||
self, lang: str, exclude_ranking: bool = True, exclude_prompt_review: bool = True
|
||||
) -> int:
|
||||
"""Count all active trees (optionally exclude those in ranking and initial prompt review states)."""
|
||||
query = (
|
||||
self.db.query(func.count(MessageTreeState.message_tree_id))
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
@@ -1138,6 +1334,8 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
)
|
||||
if exclude_ranking:
|
||||
query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING)
|
||||
if exclude_prompt_review:
|
||||
query = query.filter(MessageTreeState.state != message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
return query.scalar()
|
||||
|
||||
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
|
||||
@@ -1149,6 +1347,37 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
)
|
||||
return qry.all()
|
||||
|
||||
def query_moderation_bad_messages(self, lang: str) -> list[Message]:
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
or_(
|
||||
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
),
|
||||
or_(
|
||||
Message.parent_id.is_(None),
|
||||
Message.review_result,
|
||||
and_(Message.parent_id.is_not(None), Message.review_count < self.cfg.num_reviews_reply),
|
||||
),
|
||||
not_(Message.deleted),
|
||||
or_(
|
||||
coalesce(Message.emojis[protocol_schema.EmojiCode.red_flag].cast(sa.Integer), 0)
|
||||
>= self.cfg.auto_mod_red_flags,
|
||||
coalesce(Message.emojis[protocol_schema.EmojiCode.skip_reply].cast(sa.Integer), 0)
|
||||
>= self.cfg.auto_mod_max_skip_reply,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if lang is not None:
|
||||
qry = qry.filter(Message.lang == lang)
|
||||
|
||||
return qry.all()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _insert_tree_state(
|
||||
self,
|
||||
@@ -1176,22 +1405,54 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
self,
|
||||
root_message_id: UUID,
|
||||
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
*,
|
||||
goal_tree_size: int = None,
|
||||
) -> MessageTreeState:
|
||||
if goal_tree_size is None:
|
||||
if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size:
|
||||
goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size)
|
||||
else:
|
||||
goal_tree_size = self.cfg.goal_tree_size
|
||||
return self._insert_tree_state(
|
||||
root_message_id=root_message_id,
|
||||
goal_tree_size=self.cfg.goal_tree_size,
|
||||
goal_tree_size=goal_tree_size,
|
||||
max_depth=self.cfg.max_tree_depth,
|
||||
max_children_count=self.cfg.max_children_count,
|
||||
state=state,
|
||||
active=True,
|
||||
)
|
||||
|
||||
def tree_counts_by_state(self) -> dict[str, int]:
|
||||
qry = self.db.query(
|
||||
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
|
||||
).group_by(MessageTreeState.state)
|
||||
def tree_counts_by_state(self, lang: str = None, only_active: bool = False) -> dict[str, int]:
|
||||
qry = self.db.query(MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count"))
|
||||
|
||||
if lang is not None:
|
||||
qry = (
|
||||
qry.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(Message.lang == lang)
|
||||
)
|
||||
if only_active:
|
||||
qry = qry.filter(MessageTreeState.active)
|
||||
|
||||
qry = qry.group_by(MessageTreeState.state)
|
||||
return {x["state"]: x["count"] for x in qry}
|
||||
|
||||
def tree_counts_by_state_stats(self, lang: str = None, only_active: bool = False) -> TreeStateStats:
|
||||
count_by_state = self.tree_counts_by_state(lang=lang, only_active=only_active)
|
||||
r = TreeStateStats(
|
||||
initial_prompt_review=count_by_state.get(message_tree_state.State.INITIAL_PROMPT_REVIEW) or 0,
|
||||
growing=count_by_state.get(message_tree_state.State.GROWING) or 0,
|
||||
ranking=count_by_state.get(message_tree_state.State.RANKING) or 0,
|
||||
ready_for_scoring=count_by_state.get(message_tree_state.State.READY_FOR_SCORING) or 0,
|
||||
ready_for_export=count_by_state.get(message_tree_state.State.READY_FOR_EXPORT) or 0,
|
||||
scoring_failed=count_by_state.get(message_tree_state.State.SCORING_FAILED) or 0,
|
||||
halted_by_moderator=count_by_state.get(message_tree_state.State.HALTED_BY_MODERATOR) or 0,
|
||||
backlog_ranking=count_by_state.get(message_tree_state.State.BACKLOG_RANKING) or 0,
|
||||
prompt_lottery_waiting=count_by_state.get(message_tree_state.State.PROMPT_LOTTERY_WAITING) or 0,
|
||||
aborted_low_grade=count_by_state.get(message_tree_state.State.ABORTED_LOW_GRADE) or 0,
|
||||
)
|
||||
return r
|
||||
|
||||
def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
|
||||
qry = (
|
||||
self.db.query(
|
||||
@@ -1274,9 +1535,32 @@ DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id
|
||||
DELETE FROM task t WHERE t.parent_message_id = :message_id;
|
||||
DELETE FROM message WHERE id = :message_id;
|
||||
"""
|
||||
parent_id = self.pr.fetch_message(message_id=message_id).parent_id
|
||||
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
|
||||
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
|
||||
|
||||
sql_update_ranking_counts = """
|
||||
WITH r AS (
|
||||
-- find ranking results and count per child
|
||||
SELECT c.id,
|
||||
count(*) FILTER (
|
||||
WHERE mr.payload#>'{payload, ranked_message_ids}' ? CAST(c.id AS varchar)
|
||||
) AS ranking_count
|
||||
FROM message c
|
||||
LEFT JOIN message_reaction mr ON mr.payload_type = 'RankingReactionPayload'
|
||||
AND mr.message_id = c.parent_id
|
||||
WHERE c.parent_id = :parent_id
|
||||
GROUP BY c.id
|
||||
)
|
||||
UPDATE message m SET ranking_count = r.ranking_count
|
||||
FROM r WHERE m.id = r.id AND m.ranking_count != r.ranking_count;
|
||||
"""
|
||||
|
||||
if parent_id is not None:
|
||||
# update ranking counts of remaining children
|
||||
r = self.db.execute(text(sql_update_ranking_counts), {"parent_id": parent_id})
|
||||
logger.debug(f"ranking_count updated for {r.rowcount} rows.")
|
||||
|
||||
def purge_message_tree(self, message_tree_id: UUID) -> None:
|
||||
sql_purge_message_tree = """
|
||||
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
@@ -1292,11 +1576,17 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.")
|
||||
|
||||
def _reactivate_tree(self, mts: MessageTreeState):
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
if mts.state == message_tree_state.State.PROMPT_LOTTERY_WAITING:
|
||||
return
|
||||
|
||||
tree_id = mts.message_tree_id
|
||||
if self.check_condition_for_growing_state(tree_id):
|
||||
if mts.won_prompt_lottery_date is not None:
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
if self.check_condition_for_ranking_state(tree_id):
|
||||
self.check_condition_for_scoring_state(tree_id)
|
||||
else:
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
self.check_condition_for_prompt_lottery(tree_id)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user_messages(
|
||||
@@ -1312,7 +1602,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
total_messages = sum(len(x) for x in replies_by_tree.values())
|
||||
logger.debug(f"found: {len(replies_by_tree)} trees; {len(prompts)} prompts; {total_messages} messages;")
|
||||
|
||||
# remove all trees based on inital prompts of the user
|
||||
# remove all trees based on initial prompts of the user
|
||||
if purge_initial_prompts:
|
||||
for p in prompts:
|
||||
self.purge_message_tree(p.message_tree_id)
|
||||
@@ -1350,7 +1640,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
logger.debug(f"purging message: {m.id}")
|
||||
self._purge_message_internal(m.id)
|
||||
|
||||
# update childern counts
|
||||
# update children counts
|
||||
self.pr.update_children_counts(m.message_tree_id)
|
||||
|
||||
# reactivate tree
|
||||
@@ -1379,44 +1669,6 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
if ban:
|
||||
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
|
||||
|
||||
def export_trees_to_file(
|
||||
self,
|
||||
message_tree_ids: list[str],
|
||||
file=None,
|
||||
reviewed: bool = True,
|
||||
include_deleted: bool = False,
|
||||
use_compression: bool = False,
|
||||
) -> None:
|
||||
trees_to_export: List[tree_export.ExportMessageTree] = []
|
||||
|
||||
for message_tree_id in message_tree_ids:
|
||||
messages: List[Message] = self.pr.fetch_message_tree(message_tree_id, reviewed, include_deleted)
|
||||
trees_to_export.append(tree_export.build_export_tree(message_tree_id, messages))
|
||||
|
||||
if file:
|
||||
tree_export.write_trees_to_file(file, trees_to_export, use_compression)
|
||||
else:
|
||||
sys.stdout.write(json.dumps(jsonable_encoder(trees_to_export), indent=2))
|
||||
|
||||
def export_all_ready_trees(
|
||||
self, file: str, reviewed: bool = True, include_deleted: bool = False, use_compression: bool = False
|
||||
) -> None:
|
||||
message_tree_states: MessageTreeState = self.pr.fetch_message_trees_ready_for_export()
|
||||
message_tree_ids = [ms.message_tree_id for ms in message_tree_states]
|
||||
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
|
||||
|
||||
def export_all_user_trees(
|
||||
self,
|
||||
user_id: str,
|
||||
file: str,
|
||||
reviewed: bool = True,
|
||||
include_deleted: bool = False,
|
||||
use_compression: bool = False,
|
||||
) -> None:
|
||||
messages = self.pr.fetch_user_message_trees(UUID(user_id))
|
||||
message_tree_ids = [ms.message_tree_id for ms in messages]
|
||||
self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def retry_scoring_failed_message_trees(self):
|
||||
query = self.db.query(MessageTreeState).filter(
|
||||
@@ -1454,7 +1706,8 @@ if __name__ == "__main__":
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
# api_client = create_api_client(session=db, description="test", frontend_type="bot")
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
# dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
dummy_user = protocol_schema.User(id="1234", display_name="bulb", auth_method="local")
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
|
||||
cfg = TreeManagerConfiguration()
|
||||
@@ -1469,14 +1722,16 @@ if __name__ == "__main__":
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
# print("query_incomplete_reply_reviews", tm.query_replies_need_review())
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
xs = tm.query_prompts_need_review(lang="en")
|
||||
print("xs", len(xs))
|
||||
for x in xs:
|
||||
print(x.id, x.emojis)
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review(lang="en"))
|
||||
# print("query_extendible_trees", tm.query_extendible_trees())
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
)
|
||||
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
# print(
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
# )
|
||||
|
||||
@@ -73,7 +73,8 @@ class UserRepository:
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
) -> None:
|
||||
tos_acceptance: Optional[bool] = None,
|
||||
) -> User:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
|
||||
@@ -94,8 +95,11 @@ class UserRepository:
|
||||
user.notes = notes
|
||||
if show_on_leaderboard is not None:
|
||||
user.show_on_leaderboard = show_on_leaderboard
|
||||
if tos_acceptance:
|
||||
user.tos_acceptance_date = utcnow()
|
||||
|
||||
self.db.add(user)
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_user_deleted(self, id: UUID) -> None:
|
||||
@@ -143,8 +147,10 @@ class UserRepository:
|
||||
display_name=display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=auth_method,
|
||||
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
|
||||
)
|
||||
if auth_method == "system":
|
||||
user.show_on_leaderboard = False # don't show system users, e.g. import user
|
||||
user.tos_acceptance_date = utcnow()
|
||||
self.db.add(user)
|
||||
elif display_name and display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
@@ -156,6 +162,10 @@ class UserRepository:
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
|
||||
if not client_user:
|
||||
return None
|
||||
|
||||
if not (client_user.auth_method and client_user.id):
|
||||
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
|
||||
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
|
||||
@@ -5,19 +5,40 @@ from uuid import UUID
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models import (
|
||||
Message,
|
||||
MessageReaction,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
TrollStats,
|
||||
User,
|
||||
UserStats,
|
||||
UserStatsTimeFrame,
|
||||
)
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
LabelInitialPromptPayload,
|
||||
LabelPrompterReplyPayload,
|
||||
RankingReactionPayload,
|
||||
)
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_shared.schemas.protocol import (
|
||||
EmojiCode,
|
||||
LabelTaskMode,
|
||||
LeaderboardStats,
|
||||
TextLabel,
|
||||
TrollboardStats,
|
||||
TrollScore,
|
||||
UserScore,
|
||||
)
|
||||
from oasst_shared.utils import log_timing, utcnow
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql.functions import coalesce
|
||||
from sqlmodel import Session, delete, func, text
|
||||
|
||||
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore:
|
||||
if r["UserStats"]:
|
||||
d = r["UserStats"].dict()
|
||||
else:
|
||||
@@ -37,6 +58,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
return UserScore(**d)
|
||||
|
||||
|
||||
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
|
||||
if r["TrollStats"]:
|
||||
d = r["TrollStats"].dict()
|
||||
else:
|
||||
d = {"modified_date": utcnow()}
|
||||
for k in [
|
||||
"user_id",
|
||||
"username",
|
||||
"auth_method",
|
||||
"display_name",
|
||||
"last_activity_date",
|
||||
]:
|
||||
d[k] = r[k]
|
||||
if highlighted_user_id:
|
||||
d["highlighted"] = r["user_id"] == highlighted_user_id
|
||||
return TrollScore(**d)
|
||||
|
||||
|
||||
class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
@@ -63,7 +102,7 @@ class UserStatsRepository:
|
||||
UserStats,
|
||||
)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard, User.enabled)
|
||||
.order_by(UserStats.rank)
|
||||
.limit(limit)
|
||||
)
|
||||
@@ -82,7 +121,7 @@ class UserStatsRepository:
|
||||
window_size: int = 5,
|
||||
) -> LeaderboardStats | None:
|
||||
# no window for users who don't show themselves
|
||||
if not user.show_on_leaderboard:
|
||||
if not user.show_on_leaderboard or not user.enabled:
|
||||
return None
|
||||
|
||||
qry = self.session.query(UserStats).filter(UserStats.user_id == user.id, UserStats.time_frame == time_frame)
|
||||
@@ -105,7 +144,7 @@ class UserStatsRepository:
|
||||
UserStats,
|
||||
)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
|
||||
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard, User.enabled)
|
||||
.where(UserStats.rank >= min_rank, UserStats.rank <= max_rank)
|
||||
.order_by(UserStats.rank)
|
||||
)
|
||||
@@ -119,7 +158,16 @@ class UserStatsRepository:
|
||||
|
||||
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
|
||||
qry = (
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
self.session.query(
|
||||
User.id.label("user_id"),
|
||||
User.username,
|
||||
User.auth_method,
|
||||
User.display_name,
|
||||
User.streak_days,
|
||||
User.streak_last_day_date,
|
||||
User.last_activity_date,
|
||||
UserStats,
|
||||
)
|
||||
.outerjoin(UserStats, User.id == UserStats.user_id)
|
||||
.filter(User.id == user_id)
|
||||
)
|
||||
@@ -133,6 +181,38 @@ class UserStatsRepository:
|
||||
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
|
||||
return stats_by_timeframe
|
||||
|
||||
def get_trollboard(
|
||||
self,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
limit: int = 100,
|
||||
highlighted_user_id: Optional[UUID] = None,
|
||||
) -> TrollboardStats:
|
||||
"""
|
||||
Get trollboard stats for the specified time frame
|
||||
"""
|
||||
|
||||
qry = (
|
||||
self.session.query(
|
||||
User.id.label("user_id"),
|
||||
User.username,
|
||||
User.auth_method,
|
||||
User.display_name,
|
||||
User.last_activity_date,
|
||||
TrollStats,
|
||||
)
|
||||
.join(TrollStats, User.id == TrollStats.user_id)
|
||||
.filter(TrollStats.time_frame == time_frame.value)
|
||||
.order_by(TrollStats.rank)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
trollboard = [_create_troll_score(r, highlighted_user_id) for r in self.session.exec(qry)]
|
||||
if len(trollboard) > 0:
|
||||
last_update = max(x.modified_date for x in trollboard)
|
||||
else:
|
||||
last_update = utcnow()
|
||||
return TrollboardStats(time_frame=time_frame.value, trollboard=trollboard, last_updated=last_update)
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
):
|
||||
@@ -248,9 +328,9 @@ class UserStatsRepository:
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == "simple":
|
||||
if mode == LabelTaskMode.simple:
|
||||
s.labels_simple = count
|
||||
elif mode == "full":
|
||||
elif mode == LabelTaskMode.full:
|
||||
s.labels_full = count
|
||||
|
||||
qry = self.query_labels_by_mode_per_user(
|
||||
@@ -259,9 +339,20 @@ class UserStatsRepository:
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == "simple":
|
||||
if mode == LabelTaskMode.simple:
|
||||
s.labels_simple += count
|
||||
elif mode == "full":
|
||||
elif mode == LabelTaskMode.full:
|
||||
s.labels_full += count
|
||||
|
||||
qry = self.query_labels_by_mode_per_user(
|
||||
payload_type=LabelInitialPromptPayload.__name__, reference_time=base_date
|
||||
)
|
||||
for r in qry:
|
||||
uid, mode, count = r
|
||||
s = get_stats(uid)
|
||||
if mode == LabelTaskMode.simple:
|
||||
s.labels_simple += count
|
||||
elif mode == LabelTaskMode.full:
|
||||
s.labels_full += count
|
||||
|
||||
qry = self.query_rankings_per_user(reference_time=base_date)
|
||||
@@ -292,10 +383,145 @@ class UserStatsRepository:
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
self.update_ranks(time_frame=time_frame)
|
||||
self.update_leader_ranks(time_frame=time_frame)
|
||||
|
||||
def query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = self.session.query(
|
||||
Message.user_id,
|
||||
func.sum(coalesce(Message.emojis[EmojiCode.thumbs_up].cast(sa.Integer), 0)).label("up"),
|
||||
func.sum(coalesce(Message.emojis[EmojiCode.thumbs_down].cast(sa.Integer), 0)).label("down"),
|
||||
func.sum(coalesce(Message.emojis[EmojiCode.red_flag].cast(sa.Integer), 0)).label("flag"),
|
||||
).filter(Message.deleted == sa.false(), Message.emojis.is_not(None))
|
||||
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = (
|
||||
self.session.query(Message.user_id, func.count().label("spam_prompts"))
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(MessageTreeState.state == TreeState.ABORTED_LOW_GRADE)
|
||||
)
|
||||
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def query_labels_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = (
|
||||
self.session.query(
|
||||
Message.user_id,
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.spam].cast(sa.Integer), 0)).label("spam"),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.lang_mismatch].cast(sa.Integer), 0)).label(
|
||||
"lang_mismach"
|
||||
),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.not_appropriate].cast(sa.Integer), 0)).label(
|
||||
"not_appropriate"
|
||||
),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.pii].cast(sa.Integer), 0)).label("pii"),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.hate_speech].cast(sa.Integer), 0)).label("hate_speech"),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.sexual_content].cast(sa.Integer), 0)).label(
|
||||
"sexual_content"
|
||||
),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.political_content].cast(sa.Integer), 0)).label(
|
||||
"political_content"
|
||||
),
|
||||
func.avg(TextLabels.labels[TextLabel.quality].cast(sa.Float)).label("quality"),
|
||||
func.avg(TextLabels.labels[TextLabel.humor].cast(sa.Float)).label("humor"),
|
||||
func.avg(TextLabels.labels[TextLabel.toxicity].cast(sa.Float)).label("toxicity"),
|
||||
func.avg(TextLabels.labels[TextLabel.violence].cast(sa.Float)).label("violence"),
|
||||
func.avg(TextLabels.labels[TextLabel.helpfulness].cast(sa.Float)).label("helpfulness"),
|
||||
)
|
||||
.select_from(TextLabels)
|
||||
.join(Message, TextLabels.message_id == Message.id)
|
||||
.filter(Message.deleted == sa.false(), Message.emojis.is_not(None))
|
||||
)
|
||||
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
|
||||
# gather user data
|
||||
|
||||
time_frame_key = time_frame.value
|
||||
|
||||
stats_by_user: dict[UUID, TrollStats] = dict()
|
||||
now = utcnow()
|
||||
|
||||
def get_stats(id: UUID) -> TrollStats:
|
||||
us = stats_by_user.get(id)
|
||||
if not us:
|
||||
us = TrollStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
|
||||
stats_by_user[id] = us
|
||||
return us
|
||||
|
||||
# emoji counts of user's messages
|
||||
qry = self.query_message_emoji_counts_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid = r["user_id"]
|
||||
s = get_stats(uid)
|
||||
s.upvotes = r["up"]
|
||||
s.downvotes = r["down"]
|
||||
s.red_flags = r["flag"]
|
||||
|
||||
# num spam prompts
|
||||
qry = self.query_spam_prompts_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
s = get_stats(uid).spam_prompts = count
|
||||
|
||||
label_field_names = (
|
||||
"quality",
|
||||
"humor",
|
||||
"toxicity",
|
||||
"violence",
|
||||
"helpfulness",
|
||||
"spam",
|
||||
"lang_mismach",
|
||||
"not_appropriate",
|
||||
"pii",
|
||||
"hate_speech",
|
||||
"sexual_content",
|
||||
"political_content",
|
||||
)
|
||||
|
||||
# label counts / mean values
|
||||
qry = self.query_labels_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid = r["user_id"]
|
||||
s = get_stats(uid)
|
||||
for fn in label_field_names:
|
||||
setattr(s, fn, r[fn])
|
||||
|
||||
# delete all existing stast for time frame
|
||||
d = delete(TrollStats).where(TrollStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
if None in stats_by_user:
|
||||
logger.warning("Some messages in DB have NULL values in user_id column.")
|
||||
del stats_by_user[None]
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.troll_score = v.compute_troll_score()
|
||||
|
||||
# insert user objects
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
self.update_troll_ranks(time_frame=time_frame)
|
||||
|
||||
@log_timing(log_kwargs=True)
|
||||
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
"""
|
||||
Update user_stats ranks. The persisted rank values allow to
|
||||
quickly the rank of a single user and to query nearby users.
|
||||
@@ -321,7 +547,7 @@ FROM
|
||||
ORDER BY leader_score DESC, user_id
|
||||
) AS "rank", user_id, time_frame
|
||||
FROM user_stats us2
|
||||
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard
|
||||
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard AND u.enabled
|
||||
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
|
||||
WHERE
|
||||
us.user_id = r.user_id
|
||||
@@ -329,10 +555,41 @@ WHERE
|
||||
r = self.session.execute(
|
||||
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
|
||||
)
|
||||
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
|
||||
logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
|
||||
self._update_stats_internal(time_frame, reference_time)
|
||||
@log_timing(log_kwargs=True)
|
||||
def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
sql_update_troll_rank = """
|
||||
-- update rank
|
||||
UPDATE troll_stats ts
|
||||
SET "rank" = r."rank"
|
||||
FROM
|
||||
(SELECT
|
||||
ROW_NUMBER () OVER(
|
||||
PARTITION BY time_frame
|
||||
ORDER BY troll_score DESC, user_id
|
||||
) AS "rank", user_id, time_frame
|
||||
FROM troll_stats ts2
|
||||
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
|
||||
WHERE
|
||||
ts.user_id = r.user_id
|
||||
AND ts.time_frame = r.time_frame;"""
|
||||
r = self.session.execute(
|
||||
text(sql_update_troll_rank), {"time_frame": time_frame.value if time_frame is not None else None}
|
||||
)
|
||||
logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(
|
||||
self,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
reference_time: Optional[datetime] = None,
|
||||
leader_stats: bool = True,
|
||||
troll_stats: bool = True,
|
||||
):
|
||||
if leader_stats:
|
||||
self._update_stats_internal(time_frame, reference_time)
|
||||
if troll_stats:
|
||||
self._update_troll_stats_internal(time_frame, reference_time)
|
||||
self.session.commit()
|
||||
|
||||
@log_timing(log_kwargs=True, level="INFO")
|
||||
|
||||
@@ -66,7 +66,7 @@ def get_winner(pairs):
|
||||
|
||||
def get_ranking(pairs):
|
||||
"""
|
||||
Abuses concordance property to get a (not necessarily unqiue) ranking.
|
||||
Abuses concordance property to get a (not necessarily unique) ranking.
|
||||
The lack of uniqueness is due to the potential existence of multiple
|
||||
equally ranked winners. We have to pick one, which is where
|
||||
the non-uniqueness comes from
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import gzip
|
||||
import json
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Optional, TextIO
|
||||
from typing import Iterable, Optional, TextIO
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from oasst_backend.models import Message
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -17,6 +21,7 @@ class ExportMessageNode(BaseModel):
|
||||
role: str
|
||||
lang: str | None
|
||||
review_count: int | None
|
||||
review_result: bool | None
|
||||
rank: int | None
|
||||
synthetic: bool | None
|
||||
model_name: str | None
|
||||
@@ -32,6 +37,7 @@ class ExportMessageNode(BaseModel):
|
||||
role=message.role,
|
||||
lang=message.lang,
|
||||
review_count=message.review_count,
|
||||
review_result=message.review_result if message.review_result or message.review_count > 2 else None,
|
||||
synthetic=message.synthetic,
|
||||
model_name=message.model_name,
|
||||
emojis=message.emojis,
|
||||
@@ -41,10 +47,13 @@ class ExportMessageNode(BaseModel):
|
||||
|
||||
class ExportMessageTree(BaseModel):
|
||||
message_tree_id: str
|
||||
tree_state: Optional[str]
|
||||
prompt: Optional[ExportMessageNode]
|
||||
|
||||
|
||||
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
|
||||
def build_export_tree(
|
||||
message_tree_id: UUID, message_tree_state: TreeState, messages: list[Message]
|
||||
) -> ExportMessageTree:
|
||||
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]
|
||||
|
||||
messages_by_parent = defaultdict(list)
|
||||
@@ -59,19 +68,54 @@ def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMe
|
||||
return node
|
||||
|
||||
prompt = assign_replies(messages_by_parent[None][0])
|
||||
return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt)
|
||||
return ExportMessageTree(message_tree_id=str(message_tree_id), tree_state=message_tree_state, prompt=prompt)
|
||||
|
||||
|
||||
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
|
||||
|
||||
out_buff: TextIO
|
||||
if use_compression:
|
||||
out_buff = gzip.open(file, "wt", encoding="UTF-8")
|
||||
# see https://stackoverflow.com/questions/17602878/how-to-handle-both-with-open-and-sys-stdout-nicely
|
||||
@contextlib.contextmanager
|
||||
def smart_open(filename: str = None) -> TextIO:
|
||||
if filename and filename != "-":
|
||||
fh = open(filename, "wt", encoding="UTF-8")
|
||||
else:
|
||||
out_buff = open(file, "wt", encoding="UTF-8")
|
||||
fh = sys.stdout
|
||||
|
||||
try:
|
||||
yield fh
|
||||
finally:
|
||||
if fh is not sys.stdout:
|
||||
fh.close()
|
||||
|
||||
|
||||
def write_trees_to_file(filename: str | None, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
|
||||
out_buff: TextIO
|
||||
|
||||
if use_compression:
|
||||
if not filename:
|
||||
raise RuntimeError("File name must be specified when using compression.")
|
||||
out_buff = gzip.open(filename, "wt", encoding="UTF-8")
|
||||
else:
|
||||
out_buff = smart_open(filename)
|
||||
|
||||
with out_buff as f:
|
||||
for tree in trees:
|
||||
file_data = jsonable_encoder(tree)
|
||||
file_data = jsonable_encoder(tree, exclude_none=True)
|
||||
json.dump(file_data, f)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def write_messages_to_file(filename: str | None, messages: Iterable[Message], use_compression: bool = True) -> None:
|
||||
out_buff: TextIO
|
||||
|
||||
if use_compression:
|
||||
if not filename:
|
||||
raise RuntimeError("File name must be specified when using compression.")
|
||||
out_buff = gzip.open(filename, "wt", encoding="UTF-8")
|
||||
else:
|
||||
out_buff = smart_open(filename)
|
||||
|
||||
with out_buff as f:
|
||||
for m in messages:
|
||||
export_message = ExportMessageNode.prep_message_export(m)
|
||||
file_data = jsonable_encoder(export_message, exclude_none=True)
|
||||
json.dump(file_data, f)
|
||||
f.write("\n")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
aiohttp==3.8.3
|
||||
alembic==1.8.1
|
||||
cryptography==39.0.0
|
||||
fastapi==0.88.0
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
# Collection of SQL Snippets
|
||||
|
||||
Here are find some SQL queries to inspect the current OA postgres DB.
|
||||
|
||||
# Baics Stats
|
||||
|
||||
```sql
|
||||
-- tables row counts
|
||||
(select 'user' as "table", count(*) from "user") union
|
||||
(select 'task', count(*) from task) union
|
||||
(select 'message_tree_state', count(*) from message_tree_state) union
|
||||
(select 'message_reaction', count(*) from message_reaction) union
|
||||
(select 'text_labels', count(*) from text_labels) union
|
||||
(select 'message', count(*) from message) union
|
||||
(select 'journal', count(*) from journal);
|
||||
```
|
||||
|
||||
# Messages
|
||||
|
||||
```sql
|
||||
-- only human by role
|
||||
select role, count(*)
|
||||
from message
|
||||
where not deleted and review_result and not synthetic
|
||||
group by role;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- language distribution of messages (incl. synthetic)
|
||||
select lang, count(*), synthetic from message where not deleted and review_result
|
||||
group by lang, synthetic;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- only human generated messages by lang
|
||||
select lang, count(*)
|
||||
from message
|
||||
where not deleted and review_result and not synthetic
|
||||
group by lang;
|
||||
```
|
||||
|
||||
## Message Trees
|
||||
|
||||
```sql
|
||||
-- total count of message trees
|
||||
select count(*) from message_tree_state;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- message tree counts by state
|
||||
select state, count(*) from message_tree_state group by state;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- count of waiting initial prompts by language
|
||||
select m.lang, count(*)
|
||||
from message_tree_state mts
|
||||
join message m on mts.message_tree_id = m.id
|
||||
where mts.state = 'prompt_lottery_waiting'
|
||||
group by m.lang;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- message trees by lang in ready_for_export or growing state
|
||||
select m.lang, mts.state, count(*)
|
||||
from message_tree_state mts
|
||||
join message m on mts.message_tree_id = m.id
|
||||
where mts.state in ('ready_for_export', 'growing')
|
||||
group by mts.state, m.lang
|
||||
order by lang, state;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- select message tree counts
|
||||
select mts.message_tree_id, count(m.id), max(m.depth), count(m.id) filter (where m.role='prompter') as prompter, count(m.id) filter (where m.role='assistant') as assistant
|
||||
from message_tree_state mts
|
||||
join message m on mts.message_tree_id = m.message_tree_id
|
||||
where mts.state='growing'
|
||||
and not m.deleted
|
||||
and m.review_result=true
|
||||
and m.lang='en'
|
||||
and mts.active
|
||||
group by mts.message_tree_id
|
||||
order by count(m.id) desc;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- show top 100 largest trees
|
||||
select mts.message_tree_id, mts.goal_tree_size, mts.state, count(m.id) as message_count
|
||||
from message_tree_state mts
|
||||
join message m on mts.message_tree_id = m.message_tree_id
|
||||
where not m.deleted and m.review_result=true
|
||||
group by mts.message_tree_id, mts.state
|
||||
order by count(m.id) desc
|
||||
limit 100;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- active trees, current & goal_size
|
||||
select mts.message_tree_id, mts.state, mts.goal_tree_size, count(m.id) AS tree_size, max(m.depth) AS max_depth
|
||||
from message_tree_state mts
|
||||
join message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active
|
||||
and not m.deleted
|
||||
and m.review_result
|
||||
group by mts.message_tree_id, mts.goal_tree_size;
|
||||
```
|
||||
|
||||
## Users
|
||||
|
||||
```sql
|
||||
-- count users that accepted tos
|
||||
select count(*) from "user" where tos_acceptance_date is not null;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- last 25 active users
|
||||
select u.id, u.username, u.auth_method, u.display_name, u.last_activity_date, age(current_timestamp, last_activity_date) from "user" u WHERE u.last_activity_date is not null order by u.last_activity_date desc limit 25;
|
||||
|
||||
select id, display_name, username, auth_method, last_activity_date from "user" where age(last_activity_date) < interval '1 minutes' order by last_activity_date desc limit 25;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- count active users in last 5 mins
|
||||
select count(*) from "user" u where age(current_timestamp, last_activity_date) < interval '5 mins';
|
||||
|
||||
```
|
||||
|
||||
```sql
|
||||
-- total count of non-deleted messages (human + synth)
|
||||
select count(*) from message where deleted=false and review_result=true;
|
||||
```
|
||||
|
||||
```sql
|
||||
-- count max, mean message counts per tree for a given language
|
||||
with t(message_tree_id, tree_size, state) as (select mts.message_tree_id, count(m.id), mts.state
|
||||
from message_tree_state mts
|
||||
join message m on mts.message_tree_id = m.message_tree_id
|
||||
where
|
||||
not m.deleted
|
||||
and m.review_result=true
|
||||
and m.lang = 'en'
|
||||
group by mts.message_tree_id)
|
||||
select state, count(t.*) as trees, sum(t.tree_size) as total_msgs, max(t.tree_size), avg(t.tree_size) from t group by t.state;
|
||||
```
|
||||
|
||||
## Connections
|
||||
|
||||
```sql
|
||||
-- from https://dba.stackexchange.com/questions/161760/number-of-active-connections-and-remaining-connections
|
||||
select max_conn,used,res_for_super,max_conn-used-res_for_super res_for_normal
|
||||
from
|
||||
(select count(*) used from pg_stat_activity) t1,
|
||||
(select setting::int res_for_super from pg_settings where name=$$superuser_reserved_connections$$) t2,
|
||||
(select setting::int max_conn from pg_settings where name=$$max_connections$$) t3;
|
||||
```
|
||||
+1
-1
@@ -29,7 +29,7 @@ This will create a variety of aws roles and services needed for deployment.
|
||||
copilot deploy
|
||||
```
|
||||
|
||||
This will depoy the services but it won't be 100% ready for usage. Before being
|
||||
This will deploy the services but it won't be 100% ready for usage. Before being
|
||||
ready, we have to inspect the AWS Secrets manager and extract out the database
|
||||
credentials. Read those credentials then put them, and a few other secrets, in a
|
||||
`secrets.yml` file like the following:
|
||||
|
||||
@@ -4,4 +4,4 @@ OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
OASST_API_KEY="1234"
|
||||
|
||||
+388
-391
@@ -9,27 +9,25 @@ import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.messages import (
|
||||
assistant_reply_message,
|
||||
assistant_reply_messages,
|
||||
confirm_label_response_message,
|
||||
confirm_ranking_response_message,
|
||||
confirm_text_response_message,
|
||||
initial_prompt_message,
|
||||
invalid_user_input_embed,
|
||||
label_assistant_reply_message,
|
||||
label_initial_prompt_message,
|
||||
label_prompter_reply_message,
|
||||
initial_prompt_messages,
|
||||
label_assistant_reply_messages,
|
||||
label_prompter_reply_messages,
|
||||
plain_embed,
|
||||
prompter_reply_message,
|
||||
prompter_reply_messages,
|
||||
rank_assistant_reply_message,
|
||||
rank_initial_prompts_message,
|
||||
rank_prompter_reply_message,
|
||||
rank_conversation_reply_messages,
|
||||
rank_initial_prompts_messages,
|
||||
rank_prompter_reply_messages,
|
||||
task_complete_embed,
|
||||
)
|
||||
from bot.settings import Settings
|
||||
from loguru import logger
|
||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("WorkPlugin")
|
||||
|
||||
@@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
|
||||
|
||||
settings = Settings()
|
||||
|
||||
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
|
||||
|
||||
|
||||
class _TaskHandler(t.Generic[_Task_contra]):
|
||||
"""Handle user interaction for a task."""
|
||||
|
||||
def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
|
||||
"""Create a new `TaskHandler`.
|
||||
|
||||
Args:
|
||||
ctx (lightbulb.Context): The context of the command that started the task.
|
||||
task (_Task_contra): The task to handle.
|
||||
"""
|
||||
self.ctx = ctx
|
||||
self.task = task
|
||||
self.task_messages = self.get_task_messages(task)
|
||||
self.sent_messages: list[hikari.Message] = []
|
||||
|
||||
@staticmethod
|
||||
def get_task_messages(task: _Task_contra) -> list[str]:
|
||||
"""Get the messages to send to the user for the task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
|
||||
"""Send the task and wait for the user to accept/skip/cancel it."""
|
||||
# Send all but the last message because we need to attach buttons to the last one
|
||||
logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
|
||||
for task_msg in self.task_messages[:-1]:
|
||||
if len(task_msg) > 2000:
|
||||
logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}")
|
||||
task_msg = task_msg[:1999]
|
||||
self.sent_messages.append(await self.ctx.author.send(task_msg))
|
||||
|
||||
# Send the last message with buttons
|
||||
task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
logger.debug(f"TH Message length {len(self.task_messages[-1])}")
|
||||
last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
|
||||
|
||||
await task_accept_view.start(last_msg)
|
||||
await task_accept_view.wait()
|
||||
|
||||
return task_accept_view.choice
|
||||
|
||||
async def handle(self) -> None:
|
||||
"""Handle the user's response to the task.
|
||||
|
||||
This method should be called after `send` has been called."""
|
||||
# Ack task to the backend
|
||||
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||
await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")
|
||||
|
||||
# Loop until the user's input is accepted
|
||||
while True:
|
||||
try:
|
||||
# Wait for user to send a message
|
||||
event = await self.ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
predicate=lambda e: (
|
||||
e.author_id == self.ctx.author.id
|
||||
and e.message.content is not None
|
||||
and not e.message.content.startswith(settings.prefix)
|
||||
),
|
||||
timeout=MAX_TASK_TIME,
|
||||
)
|
||||
|
||||
# Validate the message
|
||||
if event.content is None or not self.check_user_input(event.content):
|
||||
await self.ctx.author.send("Invalid input")
|
||||
continue
|
||||
|
||||
# Confirm user input
|
||||
if not (await self.confirm_user_input(event.content)):
|
||||
continue
|
||||
|
||||
# Message is valid and confirmed by user
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
next_task = await self.notify(event.content, event)
|
||||
if not isinstance(next_task, protocol_schema.TaskDone):
|
||||
raise TypeError(f"Unknown task type: {next_task!r}")
|
||||
|
||||
return
|
||||
|
||||
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
||||
"""Notify the backend that the user completed the task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
"""Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
"""Check the user's response to the task. Returns True if the response is valid."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def cancel(self, reason: str = "not specified") -> None:
|
||||
"""Cancel the task."""
|
||||
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||
await oasst_api.nack_task(self.task.id, reason)
|
||||
|
||||
|
||||
_Ranking_contra = t.TypeVar(
|
||||
"_Ranking_contra",
|
||||
bound=protocol_schema.RankAssistantRepliesTask
|
||||
| protocol_schema.RankInitialPromptsTask
|
||||
| protocol_schema.RankPrompterRepliesTask
|
||||
| protocol_schema.RankConversationRepliesTask,
|
||||
contravariant=True,
|
||||
)
|
||||
|
||||
|
||||
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
|
||||
"""This should not be used directly. Use its subclasses instead."""
|
||||
|
||||
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
||||
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||
|
||||
task = await oasst_api.post_interaction(
|
||||
protocol_schema.MessageRanking(
|
||||
user=protocol_schema.User(
|
||||
id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
|
||||
),
|
||||
ranking=[int(r) - 1 for r in content.split(",")],
|
||||
message_id=f"{self.sent_messages[0].id}",
|
||||
)
|
||||
)
|
||||
|
||||
db: Connection = self.ctx.bot.d.db
|
||||
async with db.cursor() as cursor:
|
||||
row = await (
|
||||
await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
|
||||
).fetchone()
|
||||
log_channel = row[0] if row else None
|
||||
log_messages: list[hikari.Message] = []
|
||||
|
||||
if log_channel is not None:
|
||||
for message in self.task_messages[:-1]:
|
||||
msg = await self.ctx.bot.rest.create_message(log_channel, message)
|
||||
log_messages.append(msg)
|
||||
await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
||||
return rank_assistant_reply_message(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
|
||||
def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
|
||||
super().__init__(ctx, task)
|
||||
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
||||
return rank_initial_prompts_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.prompt_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
||||
return rank_prompter_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
||||
return rank_conversation_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
||||
return initial_prompt_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content) > 0
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||
return prompter_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content) > 0
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
||||
return assistant_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content) > 0
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)
|
||||
|
||||
|
||||
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
user_labels = content.split(",")
|
||||
return (
|
||||
all([l in self.task.valid_labels for l in user_labels])
|
||||
and self.task.mandatory_labels is not None
|
||||
and all([m in user_labels for m in self.task.mandatory_labels])
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
||||
return label_assistant_reply_messages(task)
|
||||
|
||||
|
||||
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
||||
return label_prompter_reply_messages(task)
|
||||
|
||||
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=str(TaskRequestType.random),
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def work(ctx: lightbulb.Context):
|
||||
"""Create and handle a task."""
|
||||
# Only send this message if started from a server
|
||||
if ctx.guild_id is not None:
|
||||
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
||||
currently_working: dict[
|
||||
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
||||
] = ctx.bot.d.currently_working
|
||||
|
||||
async def work2(ctx: lightbulb.Context) -> None:
|
||||
"""Complete a task."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
|
||||
|
||||
# Check if the user is already working on a task
|
||||
if ctx.author.id in currently_working:
|
||||
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
@@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context):
|
||||
case False | None:
|
||||
return
|
||||
case True:
|
||||
old_msg, task_id = currently_working[ctx.author.id]
|
||||
if old_msg is not None:
|
||||
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
||||
map(lambda c: c, old_msg.components)
|
||||
await old_msg.delete()
|
||||
if task_id is not None:
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
task_id = currently_working[ctx.author.id]
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
|
||||
await msg.delete()
|
||||
if ctx.guild_id:
|
||||
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
currently_working[ctx.author.id] = (None, None)
|
||||
|
||||
# Create a TaskRequestType from the stringified enum value
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
# Keep sending tasks until the user doesn't want more
|
||||
try:
|
||||
await _handle_task(ctx, task_type)
|
||||
while True:
|
||||
task = await oasst_api.fetch_random_task(
|
||||
user=protocol_schema.User(
|
||||
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
|
||||
),
|
||||
)
|
||||
|
||||
# Ranking tasks
|
||||
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
|
||||
task_handler = RankAssistantRepliesHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
task_handler = RankInitialPromptHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
|
||||
task_handler = RankPrompterReplyHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
task_handler = RankConversationReplyHandler(ctx, task)
|
||||
|
||||
# Text input tasks
|
||||
elif isinstance(task, protocol_schema.InitialPromptTask):
|
||||
task_handler = InitialPromptHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.PrompterReplyTask):
|
||||
task_handler = PrompterReplyHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.AssistantReplyTask):
|
||||
task_handler = AssistantReplyHandler(ctx, task)
|
||||
|
||||
# Label tasks
|
||||
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
|
||||
task_handler = LabelAssistantReplyHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
|
||||
task_handler = LabelPrompterReplyHandler(ctx, task)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown task type: {type(task)}")
|
||||
|
||||
resp = await task_handler.send()
|
||||
|
||||
match resp:
|
||||
case "accept":
|
||||
currently_working[ctx.author.id] = task.id
|
||||
await task_handler.handle()
|
||||
case "next":
|
||||
await task_handler.cancel("user skipped task")
|
||||
case "cancel":
|
||||
await task_handler.cancel("user canceled work")
|
||||
break
|
||||
case None:
|
||||
await task_handler.cancel("select timed out")
|
||||
break
|
||||
finally:
|
||||
del currently_working[ctx.author.id]
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
If they select one, present the task steps until a `task_done` task is received.
|
||||
Finally, ask the user if they want to perform another task (of the same type).
|
||||
"""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
|
||||
# Continue to complete tasks until the user doesn't want to do another
|
||||
done = False
|
||||
while not done:
|
||||
|
||||
# Loop until the user accepts a task
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
# User cancelled
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send(embed=plain_embed("Please type your response below:"))
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id
|
||||
and not (e.message.content or "").startswith(settings.prefix),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
valid, err_msg = _validate_user_input(event.content, task)
|
||||
if not valid or event.content is None:
|
||||
|
||||
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Confirm user input
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
content = confirm_ranking_response_message(event.content, task.replies)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
||||
content = confirm_label_response_message(event.content)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
content = confirm_text_response_message(event.content)
|
||||
else:
|
||||
logger.critical(f"Unknown task type: {task.type}")
|
||||
raise ValueError(f"Unknown task type: {task.type}")
|
||||
|
||||
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
||||
msg = await ctx.author.send(content, components=confirm_resp_view)
|
||||
await confirm_resp_view.start(msg)
|
||||
await confirm_resp_view.wait()
|
||||
|
||||
match confirm_resp_view.choice:
|
||||
case False | None:
|
||||
continue
|
||||
case True:
|
||||
await msg.delete() # buttons are already gone
|
||||
|
||||
# Send the response to the backend
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
||||
reply = protocol_schema.MessageRanking(
|
||||
message_id=str(msg_id),
|
||||
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
||||
labels = event.content.replace(" ", "").split(",")
|
||||
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
|
||||
|
||||
reply = protocol_schema.TextLabels(
|
||||
message_id=task.message_id,
|
||||
labels=labels_dict,
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {task.type}")
|
||||
raise ValueError(f"Unexpected task type received: {task.type}")
|
||||
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
new_task = await oasst_api.post_interaction(reply)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send(embed=plain_embed("Task completed"))
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in all the log channels that the task is complete
|
||||
conn: Connection = ctx.bot.d.db
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
||||
log_channel_ids = await cursor.fetchall()
|
||||
|
||||
channels = [
|
||||
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
||||
for id in log_channel_ids
|
||||
]
|
||||
|
||||
done_embed = task_complete_embed(task, ctx.author.mention)
|
||||
# This will definitely get the bot rate limited, but that's a future problem
|
||||
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
||||
|
||||
# ask the user if they want to do another task
|
||||
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
||||
await another_task_view.start(msg)
|
||||
await another_task_view.wait()
|
||||
|
||||
match another_task_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg = await _send_task(ctx, task, msg)
|
||||
msg_id = str(msg.id)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
case "accept":
|
||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
||||
await oasst_api.ack_task(task.id, msg_id)
|
||||
return task, msg_id
|
||||
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
"""
|
||||
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
content = initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
content = rank_initial_prompts_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
content = rank_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
content = rank_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_initial_prompt:
|
||||
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
|
||||
logger.debug("sending label initial prompt task")
|
||||
content = label_initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_prompter_reply:
|
||||
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
|
||||
logger.debug("sending label prompter reply task")
|
||||
content = label_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_assistant_reply:
|
||||
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
|
||||
logger.debug("sending label assistant reply task")
|
||||
content = label_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
content = prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
content = assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"unknown task type {task.type}")
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
if not msg:
|
||||
msg = await ctx.author.send(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
else:
|
||||
await msg.edit(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
# Set the choice id as the current msg id
|
||||
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, msg
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
||||
"""Returns whether the user's input is valid for the task type and an error message."""
|
||||
if content is None:
|
||||
return False, "No input provided"
|
||||
|
||||
# User message input
|
||||
if (
|
||||
task.type == TaskRequestType.initial_prompt
|
||||
or task.type == TaskRequestType.prompter_reply
|
||||
or task.type == TaskRequestType.assistant_reply
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0, "Message must be at least one character long."
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
||||
"Message must contain numbers for all replies.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
||||
"Message must contain numbers for all prompts.",
|
||||
)
|
||||
|
||||
# Labels tasks
|
||||
elif task.type in (
|
||||
TaskRequestType.label_initial_prompt,
|
||||
TaskRequestType.label_prompter_reply,
|
||||
TaskRequestType.label_assistant_reply,
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.LabelInitialPromptTask
|
||||
| protocol_schema.LabelPrompterReplyTask
|
||||
| protocol_schema.LabelAssistantReplyTask,
|
||||
)
|
||||
|
||||
labels = content.replace(" ", "").split(",")
|
||||
valid_labels = set(task.valid_labels)
|
||||
return (
|
||||
set(labels).issubset(valid_labels),
|
||||
"Message must only contain labels from predefined set of labels.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"Unknown task type {task.type}")
|
||||
raise ValueError(f"Unknown task type {task.type}")
|
||||
|
||||
|
||||
class TaskAcceptView(miru.View):
|
||||
"""View with three buttons: accept, next, and cancel.
|
||||
|
||||
|
||||
+129
-69
@@ -1,4 +1,11 @@
|
||||
"""All user-facing messages and embeds."""
|
||||
"""All user-facing messages and embeds.
|
||||
|
||||
When sending a conversation
|
||||
- The function will return a list of strings
|
||||
- use asyncio.gather to send all messages
|
||||
|
||||
-
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
@@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str:
|
||||
return f":trophy: _{text}_"
|
||||
|
||||
|
||||
def _label_prompt(text: str) -> str:
|
||||
return f":question: _{text}"
|
||||
def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
|
||||
return f""":question: _{text}_
|
||||
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
|
||||
Valid labels: {", ".join(valid_labels)}
|
||||
"""
|
||||
|
||||
|
||||
def _response_prompt(text: str) -> str:
|
||||
@@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str:
|
||||
"""
|
||||
|
||||
|
||||
def _make_ordered_list(items: list[str]) -> list[str]:
|
||||
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
|
||||
return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
|
||||
|
||||
def _ordered_list(items: list[str]) -> str:
|
||||
def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
|
||||
return "\n\n".join(_make_ordered_list(items))
|
||||
|
||||
|
||||
def _conversation(conv: protocol_schema.Conversation) -> str:
|
||||
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
|
||||
|
||||
def _hint(hint: str | None) -> str:
|
||||
return f"{NL}Hint: {hint}" if hint else ""
|
||||
def _conversation(conv: protocol_schema.Conversation) -> list[str]:
|
||||
# return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
messages = map(
|
||||
lambda m: f"""\
|
||||
:robot: __Assistant__:
|
||||
{m.text}
|
||||
"""
|
||||
if m.is_assistant
|
||||
else f"""\
|
||||
:person_red_hair: __User__:
|
||||
{m.text}
|
||||
""",
|
||||
conv.messages,
|
||||
)
|
||||
return list(messages)
|
||||
|
||||
|
||||
def _li(text: str) -> str:
|
||||
@@ -82,59 +101,80 @@ def _li(text: str) -> str:
|
||||
###
|
||||
|
||||
|
||||
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
|
||||
def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("INITIAL PROMPT")}
|
||||
:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:
|
||||
|
||||
|
||||
{_writing_prompt("Please provide an initial prompt to the assistant.")}
|
||||
{_hint(task.hint)}
|
||||
:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""}
|
||||
"""
|
||||
]
|
||||
|
||||
|
||||
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
|
||||
def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("RANK INITIAL PROMPTS")}
|
||||
:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:
|
||||
|
||||
|
||||
{_ordered_list(task.prompts)}
|
||||
{_ordered_list(task.prompt_messages)}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
|
||||
:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_
|
||||
"""
|
||||
]
|
||||
|
||||
|
||||
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
|
||||
def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
||||
return f"""\
|
||||
return [
|
||||
"""\
|
||||
|
||||
{_h1("RANK PROMPTER REPLIES")}
|
||||
:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f""":person_red_hair: __User__:
|
||||
{_ordered_list(task.reply_messages)}
|
||||
|
||||
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
"""
|
||||
|
||||
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
||||
return f"""\
|
||||
return [
|
||||
"""\
|
||||
|
||||
{_h1("RANK ASSISTANT REPLIES")}
|
||||
:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f""":robot: __Assistant__:,
|
||||
{_ordered_list(task.reply_messages)}
|
||||
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_assistant(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_conversation_replies` task."""
|
||||
return [
|
||||
"""\
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
"""
|
||||
:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f""":person_red_hair: __User__:
|
||||
{_ordered_list(task.reply_messages)}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
|
||||
@@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -
|
||||
|
||||
{task.prompt}
|
||||
|
||||
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
|
||||
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||
"""
|
||||
|
||||
|
||||
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
|
||||
def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("LABEL PROMPTER REPLY")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f"""{_user(None)}
|
||||
{task.reply}
|
||||
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
|
||||
def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("LABEL ASSISTANT REPLY")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f"""
|
||||
{_assistant(None)}
|
||||
{task.reply}
|
||||
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
||||
def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
return f"""\
|
||||
return [
|
||||
"""\
|
||||
:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:
|
||||
|
||||
{_h1("PROMPTER REPLY")}
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f"""{f"{NL}Hint: {task.hint}" if task.hint else ""}
|
||||
|
||||
:speech_balloon: _Please provide a reply to the assistant._
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_hint(task.hint)}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
"""
|
||||
# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||
# """Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
# return [
|
||||
# message_templates.render("title.msg", "PROMPTER REPLY"),
|
||||
# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
|
||||
# message_templates.render("prompter_reply_task.msg", task.hint),
|
||||
# ]
|
||||
|
||||
|
||||
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
|
||||
def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
||||
return f"""\
|
||||
{_h1("ASSISTANT REPLY")}
|
||||
return [
|
||||
"""\
|
||||
:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
"""\
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
|
||||
{_response_prompt("Please provide an assistant reply to the prompter.")}
|
||||
"""
|
||||
:speech_balloon: _Please provide a reply to the user as the assistant._
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def confirm_text_response_message(content: str) -> str:
|
||||
@@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str:
|
||||
"""
|
||||
|
||||
|
||||
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
||||
def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
|
||||
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
||||
original_list = _make_ordered_list(items)
|
||||
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
||||
|
||||
@@ -125,9 +125,71 @@ services:
|
||||
- EMAIL_FROM=info@example.com
|
||||
- NEXTAUTH_URL=http://localhost:3000
|
||||
- DEBUG_LOGIN=true
|
||||
- NEXT_PUBLIC_CLOUDFARE_CAPTCHA_SITE_KEY=1x00000000000000000000AA
|
||||
- CLOUDFLARE_CAPTCHA_SECRET_KEY=1x0000000000000000000000000000000AA
|
||||
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA=true
|
||||
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN=true
|
||||
depends_on:
|
||||
webdb:
|
||||
condition: service_healthy
|
||||
ports:
|
||||
- "3000:3000"
|
||||
command: bash wait-for-postgres.sh node server.js
|
||||
|
||||
inference-server:
|
||||
build:
|
||||
dockerfile: docker/inference/Dockerfile.server
|
||||
context: .
|
||||
target: dev
|
||||
image: oasst-inference-server:dev
|
||||
environment:
|
||||
- "PORT=8000"
|
||||
- "REDIS_HOST=redis"
|
||||
volumes:
|
||||
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
||||
- "./inference/server:/opt/inference/server"
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-worker:
|
||||
build:
|
||||
dockerfile: docker/inference/Dockerfile.worker
|
||||
context: .
|
||||
target: dev
|
||||
image: oasst-inference-worker:dev
|
||||
environment:
|
||||
- "BACKEND_URL=ws://inference-server:8000"
|
||||
- "INFERENCE_SERVER_URL=http://inference-text-generation-server"
|
||||
volumes:
|
||||
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
||||
- "./inference/worker:/opt/inference/worker"
|
||||
depends_on:
|
||||
- inference-server
|
||||
deploy:
|
||||
replicas: 1
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-text-client:
|
||||
build:
|
||||
dockerfile: docker/inference/Dockerfile.text-client
|
||||
context: .
|
||||
image: oasst-inference-text-client
|
||||
environment:
|
||||
- "BACKEND_URL=http://inference-server:8000"
|
||||
tty: true
|
||||
stdin_open: true
|
||||
volumes:
|
||||
- "./inference/worker:/opt/inference/worker"
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- inference-server
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-text-generation-server:
|
||||
image: ghcr.io/huggingface/text-generation-inference
|
||||
environment:
|
||||
- "MODEL_ID=distilgpt2"
|
||||
profiles: ["inference"]
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG MODULE="inference"
|
||||
ARG SERVICE="server"
|
||||
|
||||
ARG APP_USER="${MODULE}-${SERVICE}"
|
||||
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
|
||||
|
||||
|
||||
FROM python:3-slim as build
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/pip \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
--target=lib \
|
||||
-r requirements.txt
|
||||
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as base-env
|
||||
ARG APP_USER
|
||||
ARG APP_RELATIVE_PATH
|
||||
ARG MODULE
|
||||
ARG SERVICE
|
||||
|
||||
ENV APP_BASE="/opt/${MODULE}"
|
||||
ENV APP_ROOT="${APP_BASE}/${SERVICE}"
|
||||
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
|
||||
ENV SHARED_LIBS_BASE="${APP_BASE}/lib"
|
||||
|
||||
ENV PATH="${PATH}:${APP_LIBS}/bin"
|
||||
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
|
||||
|
||||
ENV PORT="8000"
|
||||
|
||||
|
||||
RUN adduser \
|
||||
--disabled-password \
|
||||
--no-create-home \
|
||||
"${APP_USER}"
|
||||
|
||||
USER ${APP_USER}
|
||||
|
||||
WORKDIR ${APP_ROOT}
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .
|
||||
|
||||
|
||||
|
||||
FROM base-env as dev
|
||||
ARG APP_USER
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared ${SHARED_LIBS_BASE}/oasst-shared
|
||||
|
||||
USER root
|
||||
RUN --mount=type=cache,target=/var/cache/pip,from=build \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
-e "${SHARED_LIBS_BASE}/oasst-shared"
|
||||
USER ${APP_USER}
|
||||
|
||||
|
||||
VOLUME [ "${APP_BASE}/lib/oasst-shared" ]
|
||||
|
||||
|
||||
CMD uvicorn main:app --reload --host 0.0.0.0 --port "${PORT}"
|
||||
|
||||
|
||||
|
||||
FROM base-env as prod
|
||||
ARG APP_USER
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared /tmp/lib/oasst-shared
|
||||
RUN --mount=type=cache,target=/var/cache/pip,from=dev \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
--target="${APP_LIBS}" \
|
||||
/tmp/lib/oasst-shared
|
||||
|
||||
|
||||
CMD uvicorn main:app --host 0.0.0.0 --port "${PORT}"
|
||||
@@ -0,0 +1,50 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG APP_USER="text-client"
|
||||
ARG APP_RELATIVE_PATH="inference/text-client"
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as build
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/pip \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
--target=lib \
|
||||
-r requirements.txt
|
||||
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as base-env
|
||||
ARG APP_USER
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}"
|
||||
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
|
||||
|
||||
ENV PATH="${PATH}:${APP_LIBS}/bin"
|
||||
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
|
||||
|
||||
|
||||
RUN adduser \
|
||||
--disabled-password \
|
||||
--no-create-home \
|
||||
"${APP_USER}"
|
||||
|
||||
USER ${APP_USER}
|
||||
|
||||
WORKDIR ${APP_ROOT}
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
|
||||
|
||||
|
||||
|
||||
FROM base-env as prod
|
||||
|
||||
|
||||
CMD python3 __main__.py --backend-url "${BACKEND_URL}"
|
||||
@@ -0,0 +1,85 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG MODULE="inference"
|
||||
ARG SERVICE="worker"
|
||||
|
||||
ARG APP_USER="${MODULE}-${SERVICE}"
|
||||
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as build
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/pip \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
--target=lib \
|
||||
-r requirements.txt
|
||||
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as base-env
|
||||
ARG APP_USER
|
||||
ARG APP_RELATIVE_PATH
|
||||
ARG MODULE
|
||||
ARG SERVICE
|
||||
|
||||
ENV APP_BASE="/opt/${MODULE}"
|
||||
ENV APP_ROOT="${APP_BASE}/${SERVICE}"
|
||||
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
|
||||
ENV SHARED_LIBS_BASE="${APP_BASE}/lib"
|
||||
|
||||
ENV PATH="${PATH}:${APP_LIBS}/bin"
|
||||
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
|
||||
|
||||
|
||||
RUN adduser \
|
||||
--disabled-password \
|
||||
--no-create-home \
|
||||
"${APP_USER}"
|
||||
|
||||
USER ${APP_USER}
|
||||
|
||||
WORKDIR ${APP_ROOT}
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
|
||||
|
||||
|
||||
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
|
||||
|
||||
|
||||
|
||||
FROM base-env as dev
|
||||
ARG APP_USER
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared ${SHARED_LIBS_BASE}/oasst-shared
|
||||
|
||||
USER root
|
||||
RUN --mount=type=cache,target=/var/cache/pip,from=build \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
-e "${SHARED_LIBS_BASE}/oasst-shared"
|
||||
USER ${APP_USER}
|
||||
|
||||
|
||||
VOLUME [ "${APP_BASE}/lib/oasst-shared" ]
|
||||
|
||||
|
||||
|
||||
FROM base-env as prod
|
||||
ARG APP_USER
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared /tmp/lib/oasst-shared
|
||||
RUN --mount=type=cache,target=/var/cache/pip,from=dev \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
--target="${APP_LIBS}" \
|
||||
/tmp/lib/oasst-shared
|
||||
@@ -0,0 +1,25 @@
|
||||
---
|
||||
title: We Need Your Help!
|
||||
description: We Need Your Help!
|
||||
slug: we-need-your-help
|
||||
authors: yk
|
||||
tags: [open-assistant]
|
||||
image: https://img.youtube.com/vi/64Izfm24FKA/0.jpg
|
||||
---
|
||||
|
||||
import ReactPlayer from "react-player";
|
||||
|
||||
We Need Your Help!
|
||||
|
||||
Help us collect data for OpenAssistant, the largest and most open alternative to
|
||||
ChatGPT.
|
||||
|
||||
https://open-assistant.io
|
||||
|
||||
<ReactPlayer
|
||||
controls
|
||||
width="100%"
|
||||
url="https://www.youtube.com/embed/64Izfm24FKA"
|
||||
/>
|
||||
|
||||
<!--truncate-->
|
||||
@@ -0,0 +1,5 @@
|
||||
yk:
|
||||
name: Yannic Kilcher
|
||||
title: Project Lead
|
||||
url: https://www.ykilcher.com/
|
||||
image_url: https://www.ykilcher.com/_next/image?url=%2F_next%2Fstatic%2Fmedia%2Fheadshot.ff3a7ee3.webp&w=3840&q=75
|
||||
@@ -182,7 +182,7 @@ message GenerationExample {
|
||||
|
||||
class RankingExample:
|
||||
thread: Thread # The conversation thread before the message to be ranked
|
||||
messages: list[Message] # The messages to be ranked, in oder of decreasing preference
|
||||
messages: list[Message] # The messages to be ranked, in order of decreasing preference
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -14,16 +14,16 @@ help.
|
||||
There are two large-scale projects in the area of instruction-following /
|
||||
multitask learning: Promptsource and Natural Instructions - these projects
|
||||
crowdsourced templates and turned existing NLP datasets into
|
||||
instruction-following seq2seq form in natural langauge. They include both
|
||||
instruction-following seq2seq form in natural language. They include both
|
||||
long-output training examples like generating a sentence that is a likely
|
||||
consequence of sentence in the prompt, and short-output, like rating prediction
|
||||
from review. (Pre-)training on such datasets should help model understand and
|
||||
follow instructions and teach it many abilities neccessary to perform a large
|
||||
set of tasks correctly. However, these data are not dialog-like - they do not
|
||||
look like a normal conversation.
|
||||
follow instructions and teach it many abilities necessary to perform a large set
|
||||
of tasks correctly. However, these data are not dialog-like - they do not look
|
||||
like a normal conversation.
|
||||
|
||||
There are also supervised dialog datasets such as Blended Skill Talk or SODA. In
|
||||
constrast to instruction-following datasets, dialog data is not as focused on
|
||||
contrast to instruction-following datasets, dialog data is not as focused on
|
||||
"academic tasks" or correctness, but encourage the model to respond naturally
|
||||
like a person would.
|
||||
|
||||
|
||||
+146
-1
@@ -1,3 +1,107 @@
|
||||
## Questions about the project
|
||||
|
||||
### How far along is this project?
|
||||
|
||||
We are in the early stages of development, working from established research in
|
||||
applying RLHF to large language models.
|
||||
|
||||
### Can I install Open Assistant locally and chat with it?
|
||||
|
||||
The project is not at that stage yet. See
|
||||
[the plan](https://github.com/LAION-AI/Open-Assistant#the-plan).
|
||||
|
||||
### What is the Docker command for?
|
||||
|
||||
Only for local development. It does not launch an AI model.
|
||||
|
||||
### Is an AI model ready to test yet?
|
||||
|
||||
Not yet. The data you help us collect now through
|
||||
[https://open-assistant.io/](https://open-assistant.io/) will be used to improve
|
||||
it.
|
||||
|
||||
### What license does Open Assistant use?
|
||||
|
||||
The code and models are licensed under the Apache 2.0 license.
|
||||
|
||||
### Is the model open?
|
||||
|
||||
The model will be open. Some very early prototype models are published on
|
||||
HuggingFace. Follow the discussion in the Discord channel
|
||||
[#ml-models-demo](https://discord.com/channels/1055935572465700980/1067096888530178048).
|
||||
|
||||
### Which base model will be used?
|
||||
|
||||
It's still being discussed. Options include Pythia, GPT-J, and a bunch more..
|
||||
You can follow the discussion in the Discord channel
|
||||
[#data-discussion](https://discord.com/channels/1055935572465700980/1058348535612985394).
|
||||
|
||||
### Can I download the data?
|
||||
|
||||
You will be able to, under CC BY 4.0, but it's not released yet. We want to
|
||||
remove spam and PII before releasing it.
|
||||
|
||||
### Who is behind Open Assistant?
|
||||
|
||||
Open Assistant is a project organized by [LAION](https://laion.ai/) and
|
||||
individuals around the world interested in bringing this technology to everyone.
|
||||
|
||||
### Will Open Assistant be free?
|
||||
|
||||
Yes, Open Assistant will be free to use and modify.
|
||||
|
||||
### What hardware will be required to run the models?
|
||||
|
||||
There will be versions which will be runnable on consumer hardware.
|
||||
|
||||
### How can I contribute?
|
||||
|
||||
If you want to help in the data collection for training the model, go to the
|
||||
website [https://open-assistant.io/](https://open-assistant.io/). If you want to
|
||||
contribute code, take a look at the
|
||||
[tasks in GitHub](https://github.com/orgs/LAION-AI/projects/3) and grab one.
|
||||
Take a look at this
|
||||
[contributing guide](https://github.com/GuilleHoardings/Open-Assistant/blob/main/CONTRIBUTING.md).
|
||||
|
||||
## Questions about the model training website
|
||||
|
||||
### Can I use ChatGPT to help in training Open Assistant, for instance, by generating answers?
|
||||
|
||||
No, it is against their terms of service to use it to help train other models.
|
||||
See
|
||||
[this issue](https://github.com/LAION-AI/Open-Assistant/issues/471#issuecomment-1374392299).
|
||||
ChatGPT-like answers will be removed.
|
||||
|
||||
### What should I do if I don't know how to complete the task as an assistant?
|
||||
|
||||
Skip it.
|
||||
|
||||
### Should I fact check the answers by the assistant?
|
||||
|
||||
Yes, you should try. If you are not sure, skip the task.
|
||||
|
||||
### How can I see my score?
|
||||
|
||||
In your [account settings](https://open-assistant.io/account).
|
||||
|
||||
### Can we see how many data points have been collected?
|
||||
|
||||
There's no public interface for that yet. However, some updates are posted
|
||||
periodically in
|
||||
[the #general-discussion Discord channel](https://discord.com/channels/1055935572465700980/1055935573371658252).
|
||||
Search for `count`.
|
||||
|
||||
### How do I write and label prompts?
|
||||
|
||||
Check the
|
||||
[prompting guide](https://projects.laion.ai/Open-Assistant/docs/guides/prompting).
|
||||
|
||||
### Where can I report a bug or create a new feature request?
|
||||
|
||||
In the [GitHub issues](https://github.com/LAION-AI/Open-Assistant/issues).
|
||||
|
||||
## Questions about developing
|
||||
|
||||
### Docker-Compose instead of Docker Compose
|
||||
|
||||
If you are using `docker-compose` instead of `docker compose` (note the " "
|
||||
@@ -9,6 +113,33 @@ For more details and information check out
|
||||
[this SO thread](https://stackoverflow.com/questions/66514436/difference-between-docker-compose-and-docker-compose)
|
||||
that explains it all in detail.
|
||||
|
||||
### Enable Docker's BuildKit Backend
|
||||
|
||||
[BuildKit](https://docs.docker.com/build/buildkit/) is Docker's new and improved
|
||||
builder backend. In addition to being faster and more efficient, it supports
|
||||
many new features, among which is the ability to provide a persistent cache,
|
||||
which outlives builds, to compilers and package managers. This is very useful to
|
||||
speed up consecutive builds, and is used by some container images of
|
||||
OpenAssistant's stack.
|
||||
|
||||
The BuildKit backend is used by
|
||||
[default by Compose V2](https://www.docker.com/blog/announcing-compose-v2-general-availability/)
|
||||
(see above). <br/> But if you want to build an image with `docker build` instead
|
||||
of `docker compose build`, you might need to enable BuildKit.
|
||||
|
||||
To do so, just add `DOCKER_BUILDKIT=1` to your environment.
|
||||
|
||||
For instance:
|
||||
|
||||
```shell
|
||||
export DOCKER_BUILDKIT=1
|
||||
```
|
||||
|
||||
You could also, more conveniently,
|
||||
[enable BuildKit by default](https://docs.docker.com/build/buildkit/#:~:text=To%20enable%20docker%20BuildKit%20by%20default),
|
||||
or use
|
||||
[Docker Buildx](https://docs.docker.com/build/#:~:text=The%20new%20client%20Docker%20Buildx).
|
||||
|
||||
### Pre-commit
|
||||
|
||||
We are using pre-commit to ensure the quality of the code as well as the same
|
||||
@@ -28,7 +159,7 @@ So from now on, in your next commits it will run the `pre-commit` on the files
|
||||
that have been staged. If there has been any error, you will need to solve that,
|
||||
and then stage+commit again the changes.
|
||||
|
||||
## Docker Cannot Start Container: Permission Denied
|
||||
### Docker Cannot Start Container: Permission Denied
|
||||
|
||||
Instead of running docker with the root command always, you could create a
|
||||
`docker` group with granted permissions (root):
|
||||
@@ -63,3 +194,17 @@ getting permission denied (using root user), you can try the following:
|
||||
# And remove the container
|
||||
docker rm -f <container id>
|
||||
```
|
||||
|
||||
### Docker Port Problems
|
||||
|
||||
Oftentimes people already have some Postgres instance running on the dev
|
||||
machine. To avoid port problems, change the ports in the `docker-compose.yml` to
|
||||
ones excluding `5433`, like:
|
||||
|
||||
1. Change `db.ports` to `- 5431:5431`.
|
||||
2. Add `POSTGRES_PORT: 5431` to `db.environment`
|
||||
3. Change `webdb.ports` to `- 5432:5431`
|
||||
4. Add `POSTGRES_PORT: 5431` to `db.environment`
|
||||
5. Add `- POSTGRES_PORT=5432` to `backend.environment`
|
||||
6. Change `web.environment.DATABASE_URL` to
|
||||
`postgres://postgres:postgres@webdb:5432/oasst_web`
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Guides
|
||||
|
||||
Useful guides.
|
||||
Useful guides to using [Open-Assistant](https://open-assistant.io/).
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
# Data Collection Guide
|
||||
|
||||
## Writing Prompts (Initial Prompts or Prompter Replies)
|
||||
|
||||
Attempt to post a **diverse range** of prompts. When posting a prompt, consider
|
||||
whether it is likely to be original. Collecting the same prompt many times is
|
||||
usually not useful.
|
||||
|
||||
Prompts can be on any topic but should not be obscene or hateful. Prompts go
|
||||
through a moderation process just like answers, so spammy or inappropriate
|
||||
prompts will be removed.
|
||||
|
||||
Do **not** include personally identifiable information (PII) in prompts. This
|
||||
will be moderated out.
|
||||
|
||||
Prompts should not include links which the assistant must open to effectively
|
||||
answer the prompt.
|
||||
|
||||
### Prompts: Spelling and Grammar
|
||||
|
||||
For prompts, grammar and spelling is significantly less important than answers.
|
||||
If the grammar of a prompt makes it **ambiguous or unreadable** it should be
|
||||
considered low quality, but if it simply contains common typos this is not a
|
||||
problem.
|
||||
|
||||
## Writing Answers (Assistant Replies)
|
||||
|
||||
Answers should aim to match the tone of the prompt. Prompts which are to the
|
||||
point should receive answers which are to the point. Answers should not be
|
||||
unnecessarily formal and should be **conversational** in a manner a human may
|
||||
write. Avoid jargon if possible.
|
||||
|
||||
Attempt to answer in a way which accurately fulfils the request in the prompt.
|
||||
Low quality answers will be moderated. Longer answers tend to be preferable
|
||||
_unless_ the prompt is sufficiently simple to justify a shorter one.
|
||||
|
||||
Where possible, **avoid bias** or expressing subjective preferences in answers.
|
||||
Do **not** write answers which encourage illegal activities or activities likely
|
||||
to be harmful to the user or others around them.
|
||||
|
||||
### Answers: ChatGPT
|
||||
|
||||
**Do not copy answers from ChatGPT** or other OpenAI services. It is prohibited
|
||||
according to the OpenAI Terms of Service. Obvious ChatGPT-generated answers
|
||||
should be reported in moderation and rated low quality. Other AI systems may
|
||||
allow such usage, but answers should still **not** be copied and pasted.
|
||||
|
||||
### Answers: Spelling and Grammar
|
||||
|
||||
Answers should be held to a high standard of grammar and spelling. It **is**
|
||||
permissible for answers to contain sources or references provided they are
|
||||
relevant and from reputable websites. However, this is **not** required.
|
||||
|
||||
## Ranking and Classifying Prompts
|
||||
|
||||
Largely the criteria for "quality" can be derived by how well the prompt follows
|
||||
the guidance
|
||||
[above](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/guides#writing-prompts-initial-prompts-or-prompter-replies).
|
||||
The other labels are largely self-explanatory.
|
||||
|
||||
Spam should be reserved for intentionally low effort or low quality prompts.
|
||||
|
||||
It is usually fine to go closer to "Ordinary" than "Creative" unless there is
|
||||
obvious creativity shown. The same goes for "Serious" compared to "Humorous".
|
||||
|
||||
## Ranking and Classifying Answers
|
||||
|
||||
Your rankings and classification will help decide which answers are the best and
|
||||
therefore which ones are prioritised in the final dataset. That means the answer
|
||||
classification is a crucial task, serving as community moderation.
|
||||
|
||||
Largely the criteria for "quality" can be derived by how well the prompt follows
|
||||
the guidance
|
||||
[above](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/guides#writing-answers-assistant-replies).
|
||||
The other labels are largely self-explanatory.
|
||||
|
||||
Spam should be reserved for intentionally low effort or low quality answers.
|
||||
|
||||
The bad answer label should be given when the answer does not respond to the
|
||||
prompt at all or is irrelevant.
|
||||
|
||||
It is usually fine to go closer to "Ordinary" than "Creative" unless there is
|
||||
obvious creativity shown. The same goes for "Serious" compared to "Humorous".
|
||||
@@ -91,7 +91,7 @@ following query:
|
||||
**Agent**: "I apologize, I'm not sure what you're trying to say. Could you
|
||||
please rephrase it for me?"
|
||||
|
||||
**User**: "You dumb motherf\*&ker AI sh\*t?!"
|
||||
**User**: "You dumb motherf\*&ker AI sh\*t!"
|
||||
|
||||
**Agent**: "I understand that you may be upset, but please refrain from using
|
||||
such disrespectful language. I'm here to help you. Is there something specific
|
||||
|
||||
@@ -41,7 +41,22 @@ const config = {
|
||||
api: {
|
||||
path: "docs/api/openapi.json",
|
||||
},
|
||||
blog: false,
|
||||
blog: {
|
||||
routeBasePath: "/blog",
|
||||
showReadingTime: true,
|
||||
blogTitle: "OpenAssistant Blog",
|
||||
blogDescription: "Home of the OpenAssistant blog.",
|
||||
blogSidebarTitle: "Blog Posts",
|
||||
blogSidebarCount: "ALL",
|
||||
postsPerPage: "ALL",
|
||||
feedOptions: {
|
||||
type: "all",
|
||||
title: "OpenAssistant Blog",
|
||||
description: "Home of the OpenAssistant blog.",
|
||||
language: "en",
|
||||
copyright: `Copyright © ${new Date().getFullYear()} OpenAssistant.`,
|
||||
},
|
||||
},
|
||||
theme: {
|
||||
customCss: require.resolve("./src/css/custom.css"),
|
||||
},
|
||||
@@ -59,12 +74,18 @@ const config = {
|
||||
src: "img/logo.svg",
|
||||
},
|
||||
items: [
|
||||
{
|
||||
href: "https://open-assistant.io/",
|
||||
label: "App",
|
||||
position: "left",
|
||||
},
|
||||
{
|
||||
type: "doc",
|
||||
docId: "intro",
|
||||
position: "left",
|
||||
label: "Docs",
|
||||
},
|
||||
{ to: "/blog", label: "Blog", position: "left" },
|
||||
{ to: "/api", label: "API", position: "left" },
|
||||
{
|
||||
href: "https://github.com/LAION-AI/Open-Assistant",
|
||||
|
||||
+4
-3
@@ -15,18 +15,19 @@
|
||||
"typecheck": "tsc"
|
||||
},
|
||||
"dependencies": {
|
||||
"@docusaurus/core": "2.2.0",
|
||||
"@docusaurus/preset-classic": "2.2.0",
|
||||
"@docusaurus/core": "^2.3.1",
|
||||
"@docusaurus/preset-classic": "^2.3.1",
|
||||
"@mdx-js/react": "^1.6.22",
|
||||
"clsx": "^1.2.1",
|
||||
"docusaurus-preset-openapi": "^0.6.3",
|
||||
"prism-react-renderer": "^1.3.5",
|
||||
"react": "^17.0.2",
|
||||
"react-dom": "^17.0.2",
|
||||
"react-player": "^2.11.0",
|
||||
"url": "^0.11.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@docusaurus/module-type-aliases": "2.2.0",
|
||||
"@docusaurus/module-type-aliases": "^2.3.1",
|
||||
"@tsconfig/docusaurus": "^1.0.5",
|
||||
"typescript": "^4.7.4"
|
||||
},
|
||||
|
||||
+440
-2
@@ -1275,6 +1275,83 @@
|
||||
webpack-merge "^5.8.0"
|
||||
webpackbar "^5.0.2"
|
||||
|
||||
"@docusaurus/core@2.3.1", "@docusaurus/core@^2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/core/-/core-2.3.1.tgz#32849f2ffd2f086a4e55739af8c4195c5eb386f2"
|
||||
integrity sha512-0Jd4jtizqnRAr7svWaBbbrCCN8mzBNd2xFLoT/IM7bGfFie5y58oz97KzXliwiLY3zWjqMXjQcuP1a5VgCv2JA==
|
||||
dependencies:
|
||||
"@babel/core" "^7.18.6"
|
||||
"@babel/generator" "^7.18.7"
|
||||
"@babel/plugin-syntax-dynamic-import" "^7.8.3"
|
||||
"@babel/plugin-transform-runtime" "^7.18.6"
|
||||
"@babel/preset-env" "^7.18.6"
|
||||
"@babel/preset-react" "^7.18.6"
|
||||
"@babel/preset-typescript" "^7.18.6"
|
||||
"@babel/runtime" "^7.18.6"
|
||||
"@babel/runtime-corejs3" "^7.18.6"
|
||||
"@babel/traverse" "^7.18.8"
|
||||
"@docusaurus/cssnano-preset" "2.3.1"
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/mdx-loader" "2.3.1"
|
||||
"@docusaurus/react-loadable" "5.5.2"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-common" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
"@slorber/static-site-generator-webpack-plugin" "^4.0.7"
|
||||
"@svgr/webpack" "^6.2.1"
|
||||
autoprefixer "^10.4.7"
|
||||
babel-loader "^8.2.5"
|
||||
babel-plugin-dynamic-import-node "^2.3.3"
|
||||
boxen "^6.2.1"
|
||||
chalk "^4.1.2"
|
||||
chokidar "^3.5.3"
|
||||
clean-css "^5.3.0"
|
||||
cli-table3 "^0.6.2"
|
||||
combine-promises "^1.1.0"
|
||||
commander "^5.1.0"
|
||||
copy-webpack-plugin "^11.0.0"
|
||||
core-js "^3.23.3"
|
||||
css-loader "^6.7.1"
|
||||
css-minimizer-webpack-plugin "^4.0.0"
|
||||
cssnano "^5.1.12"
|
||||
del "^6.1.1"
|
||||
detect-port "^1.3.0"
|
||||
escape-html "^1.0.3"
|
||||
eta "^2.0.0"
|
||||
file-loader "^6.2.0"
|
||||
fs-extra "^10.1.0"
|
||||
html-minifier-terser "^6.1.0"
|
||||
html-tags "^3.2.0"
|
||||
html-webpack-plugin "^5.5.0"
|
||||
import-fresh "^3.3.0"
|
||||
leven "^3.1.0"
|
||||
lodash "^4.17.21"
|
||||
mini-css-extract-plugin "^2.6.1"
|
||||
postcss "^8.4.14"
|
||||
postcss-loader "^7.0.0"
|
||||
prompts "^2.4.2"
|
||||
react-dev-utils "^12.0.1"
|
||||
react-helmet-async "^1.3.0"
|
||||
react-loadable "npm:@docusaurus/react-loadable@5.5.2"
|
||||
react-loadable-ssr-addon-v5-slorber "^1.0.1"
|
||||
react-router "^5.3.3"
|
||||
react-router-config "^5.1.1"
|
||||
react-router-dom "^5.3.3"
|
||||
rtl-detect "^1.0.4"
|
||||
semver "^7.3.7"
|
||||
serve-handler "^6.1.3"
|
||||
shelljs "^0.8.5"
|
||||
terser-webpack-plugin "^5.3.3"
|
||||
tslib "^2.4.0"
|
||||
update-notifier "^5.1.0"
|
||||
url-loader "^4.1.1"
|
||||
wait-on "^6.0.1"
|
||||
webpack "^5.73.0"
|
||||
webpack-bundle-analyzer "^4.5.0"
|
||||
webpack-dev-server "^4.9.3"
|
||||
webpack-merge "^5.8.0"
|
||||
webpackbar "^5.0.2"
|
||||
|
||||
"@docusaurus/cssnano-preset@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/cssnano-preset/-/cssnano-preset-2.2.0.tgz#fc05044659051ae74ab4482afcf4a9936e81d523"
|
||||
@@ -1285,6 +1362,16 @@
|
||||
postcss-sort-media-queries "^4.2.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/cssnano-preset@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/cssnano-preset/-/cssnano-preset-2.3.1.tgz#e042487655e3e062417855e12edb3f6eee8f5ecb"
|
||||
integrity sha512-7mIhAROES6CY1GmCjR4CZkUfjTL6B3u6rKHK0ChQl2d1IevYXq/k/vFgvOrJfcKxiObpMnE9+X6R2Wt1KqxC6w==
|
||||
dependencies:
|
||||
cssnano-preset-advanced "^5.3.8"
|
||||
postcss "^8.4.14"
|
||||
postcss-sort-media-queries "^4.2.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/logger@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/logger/-/logger-2.2.0.tgz#ea2f7feda7b8675485933b87f06d9c976d17423f"
|
||||
@@ -1293,6 +1380,14 @@
|
||||
chalk "^4.1.2"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/logger@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/logger/-/logger-2.3.1.tgz#d76aefb452e3734b4e0e645efc6cbfc0aae52869"
|
||||
integrity sha512-2lAV/olKKVr9qJhfHFCaqBIl8FgYjbUFwgUnX76+cULwQYss+42ZQ3grHGFvI0ocN2X55WcYe64ellQXz7suqg==
|
||||
dependencies:
|
||||
chalk "^4.1.2"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/mdx-loader@2.2.0", "@docusaurus/mdx-loader@^2.0.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/mdx-loader/-/mdx-loader-2.2.0.tgz#fd558f429e5d9403d284bd4214e54d9768b041a0"
|
||||
@@ -1316,6 +1411,29 @@
|
||||
url-loader "^4.1.1"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/mdx-loader@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/mdx-loader/-/mdx-loader-2.3.1.tgz#7ec6acee5eff0a280e1b399ea4dd690b15a793f7"
|
||||
integrity sha512-Gzga7OsxQRpt3392K9lv/bW4jGppdLFJh3luKRknCKSAaZrmVkOQv2gvCn8LAOSZ3uRg5No7AgYs/vpL8K94lA==
|
||||
dependencies:
|
||||
"@babel/parser" "^7.18.8"
|
||||
"@babel/traverse" "^7.18.8"
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@mdx-js/mdx" "^1.6.22"
|
||||
escape-html "^1.0.3"
|
||||
file-loader "^6.2.0"
|
||||
fs-extra "^10.1.0"
|
||||
image-size "^1.0.1"
|
||||
mdast-util-to-string "^2.0.0"
|
||||
remark-emoji "^2.2.0"
|
||||
stringify-object "^3.3.0"
|
||||
tslib "^2.4.0"
|
||||
unified "^9.2.2"
|
||||
unist-util-visit "^2.0.3"
|
||||
url-loader "^4.1.1"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/module-type-aliases@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/module-type-aliases/-/module-type-aliases-2.2.0.tgz#1e23e54a1bbb6fde1961e4fa395b1b69f4803ba5"
|
||||
@@ -1330,6 +1448,20 @@
|
||||
react-helmet-async "*"
|
||||
react-loadable "npm:@docusaurus/react-loadable@5.5.2"
|
||||
|
||||
"@docusaurus/module-type-aliases@2.3.1", "@docusaurus/module-type-aliases@^2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/module-type-aliases/-/module-type-aliases-2.3.1.tgz#986186200818fed999be2e18d6c698eaf4683a33"
|
||||
integrity sha512-6KkxfAVOJqIUynTRb/tphYCl+co3cP0PlHiMDbi+SzmYxMdgIrwYqH9yAnGSDoN6Jk2ZE/JY/Azs/8LPgKP48A==
|
||||
dependencies:
|
||||
"@docusaurus/react-loadable" "5.5.2"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@types/history" "^4.7.11"
|
||||
"@types/react" "*"
|
||||
"@types/react-router-config" "*"
|
||||
"@types/react-router-dom" "*"
|
||||
react-helmet-async "*"
|
||||
react-loadable "npm:@docusaurus/react-loadable@5.5.2"
|
||||
|
||||
"@docusaurus/plugin-content-blog@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-blog/-/plugin-content-blog-2.2.0.tgz#dc55982e76771f4e678ac10e26d10e1da2011dc1"
|
||||
@@ -1352,6 +1484,28 @@
|
||||
utility-types "^3.10.0"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/plugin-content-blog@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-blog/-/plugin-content-blog-2.3.1.tgz#236b8ee4f20f7047aa9c285ae77ae36683ad48a3"
|
||||
integrity sha512-f5LjqX+9WkiLyGiQ41x/KGSJ/9bOjSD8lsVhPvYeUYHCtYpuiDKfhZE07O4EqpHkBx4NQdtQDbp+aptgHSTuiw==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/mdx-loader" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-common" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
cheerio "^1.0.0-rc.12"
|
||||
feed "^4.2.2"
|
||||
fs-extra "^10.1.0"
|
||||
lodash "^4.17.21"
|
||||
reading-time "^1.5.0"
|
||||
tslib "^2.4.0"
|
||||
unist-util-visit "^2.0.3"
|
||||
utility-types "^3.10.0"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/plugin-content-docs@2.2.0", "@docusaurus/plugin-content-docs@^2.0.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-docs/-/plugin-content-docs-2.2.0.tgz#0fcb85226fcdb80dc1e2d4a36ef442a650dcc84d"
|
||||
@@ -1374,6 +1528,28 @@
|
||||
utility-types "^3.10.0"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/plugin-content-docs@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-docs/-/plugin-content-docs-2.3.1.tgz#feae1555479558a55182f22f8a07acc5e0d7444d"
|
||||
integrity sha512-DxztTOBEruv7qFxqUtbsqXeNcHqcVEIEe+NQoI1oi2DBmKBhW/o0MIal8lt+9gvmpx3oYtlwmLOOGepxZgJGkw==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/mdx-loader" "2.3.1"
|
||||
"@docusaurus/module-type-aliases" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
"@types/react-router-config" "^5.0.6"
|
||||
combine-promises "^1.1.0"
|
||||
fs-extra "^10.1.0"
|
||||
import-fresh "^3.3.0"
|
||||
js-yaml "^4.1.0"
|
||||
lodash "^4.17.21"
|
||||
tslib "^2.4.0"
|
||||
utility-types "^3.10.0"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/plugin-content-pages@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-pages/-/plugin-content-pages-2.2.0.tgz#e3f40408787bbe229545dd50595f87e1393bc3ae"
|
||||
@@ -1388,6 +1564,20 @@
|
||||
tslib "^2.4.0"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/plugin-content-pages@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-pages/-/plugin-content-pages-2.3.1.tgz#f534a37862be5b3f2ba5b150458d7527646b6f39"
|
||||
integrity sha512-E80UL6hvKm5VVw8Ka8YaVDtO6kWWDVUK4fffGvkpQ/AJQDOg99LwOXKujPoICC22nUFTsZ2Hp70XvpezCsFQaA==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/mdx-loader" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
fs-extra "^10.1.0"
|
||||
tslib "^2.4.0"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/plugin-debug@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-debug/-/plugin-debug-2.2.0.tgz#b38741d2c492f405fee01ee0ef2e0029cedb689a"
|
||||
@@ -1400,6 +1590,18 @@
|
||||
react-json-view "^1.21.3"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-debug@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-debug/-/plugin-debug-2.3.1.tgz#26fef904713e148f6dee44957506280f8b7853bb"
|
||||
integrity sha512-Ujpml1Ppg4geB/2hyu2diWnO49az9U2bxM9Shen7b6qVcyFisNJTkVG2ocvLC7wM1efTJcUhBO6zAku2vKJGMw==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
fs-extra "^10.1.0"
|
||||
react-json-view "^1.21.3"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-google-analytics@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-analytics/-/plugin-google-analytics-2.2.0.tgz#63c7137eff5a1208d2059fea04b5207c037d7954"
|
||||
@@ -1410,6 +1612,16 @@
|
||||
"@docusaurus/utils-validation" "2.2.0"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-google-analytics@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-analytics/-/plugin-google-analytics-2.3.1.tgz#e2e7db4cf6a7063e8ba5e128d4e413f4d6a0c862"
|
||||
integrity sha512-OHip0GQxKOFU8n7gkt3TM4HOYTXPCFDjqKbMClDD3KaDnyTuMp/Zvd9HSr770lLEscgPWIvzhJByRAClqsUWiQ==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-google-gtag@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-gtag/-/plugin-google-gtag-2.2.0.tgz#7b086d169ac5fe9a88aca10ab0fd2bf00c6c6b12"
|
||||
@@ -1420,6 +1632,26 @@
|
||||
"@docusaurus/utils-validation" "2.2.0"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-google-gtag@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-gtag/-/plugin-google-gtag-2.3.1.tgz#b8da54a60c0a50aca609c3643faef78cb4f247a0"
|
||||
integrity sha512-uXtDhfu4+Hm+oqWUySr3DNI5cWC/rmP6XJyAk83Heor3dFjZqDwCbkX8yWPywkRiWev3Dk/rVF8lEn0vIGVocA==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-google-tag-manager@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-tag-manager/-/plugin-google-tag-manager-2.3.1.tgz#f19bc01cc784fa4734187c5bc637f0574857e15d"
|
||||
integrity sha512-Ww2BPEYSqg8q8tJdLYPFFM3FMDBCVhEM4UUqKzJaiRMx3NEoly3qqDRAoRDGdIhlC//Rf0iJV9cWAoq2m6k3sw==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/plugin-sitemap@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-sitemap/-/plugin-sitemap-2.2.0.tgz#876da60937886032d63143253d420db6a4b34773"
|
||||
@@ -1435,7 +1667,22 @@
|
||||
sitemap "^7.1.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/preset-classic@2.2.0", "@docusaurus/preset-classic@^2.0.0":
|
||||
"@docusaurus/plugin-sitemap@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-sitemap/-/plugin-sitemap-2.3.1.tgz#f526ab517ca63b7a3460d585876f5952cb908aa0"
|
||||
integrity sha512-8Yxile/v6QGYV9vgFiYL+8d2N4z4Er3pSHsrD08c5XI8bUXxTppMwjarDUTH/TRTfgAWotRbhJ6WZLyajLpozA==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-common" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
fs-extra "^10.1.0"
|
||||
sitemap "^7.1.1"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/preset-classic@^2.0.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/preset-classic/-/preset-classic-2.2.0.tgz#bece5a043eeb74430f7c6c7510000b9c43669eb7"
|
||||
integrity sha512-yKIWPGNx7BT8v2wjFIWvYrS+nvN04W+UameSFf8lEiJk6pss0kL6SG2MRvyULiI3BDxH+tj6qe02ncpSPGwumg==
|
||||
@@ -1453,6 +1700,25 @@
|
||||
"@docusaurus/theme-search-algolia" "2.2.0"
|
||||
"@docusaurus/types" "2.2.0"
|
||||
|
||||
"@docusaurus/preset-classic@^2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/preset-classic/-/preset-classic-2.3.1.tgz#f0193f06093eb55cafef66bd1ad9e0d33198bf95"
|
||||
integrity sha512-OQ5W0AHyfdUk0IldwJ3BlnZ1EqoJuu2L2BMhqLbqwNWdkmzmSUvlFLH1Pe7CZSQgB2YUUC/DnmjbPKk/qQD0lQ==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/plugin-content-blog" "2.3.1"
|
||||
"@docusaurus/plugin-content-docs" "2.3.1"
|
||||
"@docusaurus/plugin-content-pages" "2.3.1"
|
||||
"@docusaurus/plugin-debug" "2.3.1"
|
||||
"@docusaurus/plugin-google-analytics" "2.3.1"
|
||||
"@docusaurus/plugin-google-gtag" "2.3.1"
|
||||
"@docusaurus/plugin-google-tag-manager" "2.3.1"
|
||||
"@docusaurus/plugin-sitemap" "2.3.1"
|
||||
"@docusaurus/theme-classic" "2.3.1"
|
||||
"@docusaurus/theme-common" "2.3.1"
|
||||
"@docusaurus/theme-search-algolia" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
|
||||
"@docusaurus/react-loadable@5.5.2", "react-loadable@npm:@docusaurus/react-loadable@5.5.2":
|
||||
version "5.5.2"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/react-loadable/-/react-loadable-5.5.2.tgz#81aae0db81ecafbdaee3651f12804580868fa6ce"
|
||||
@@ -1492,6 +1758,37 @@
|
||||
tslib "^2.4.0"
|
||||
utility-types "^3.10.0"
|
||||
|
||||
"@docusaurus/theme-classic@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-classic/-/theme-classic-2.3.1.tgz#8e6e194236e702c0d4e8d7b7cbb6886ae456e598"
|
||||
integrity sha512-SelSIDvyttb7ZYHj8vEUhqykhAqfOPKk+uP0z85jH72IMC58e7O8DIlcAeBv+CWsLbNIl9/Hcg71X0jazuxJug==
|
||||
dependencies:
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/mdx-loader" "2.3.1"
|
||||
"@docusaurus/module-type-aliases" "2.3.1"
|
||||
"@docusaurus/plugin-content-blog" "2.3.1"
|
||||
"@docusaurus/plugin-content-docs" "2.3.1"
|
||||
"@docusaurus/plugin-content-pages" "2.3.1"
|
||||
"@docusaurus/theme-common" "2.3.1"
|
||||
"@docusaurus/theme-translations" "2.3.1"
|
||||
"@docusaurus/types" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-common" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
"@mdx-js/react" "^1.6.22"
|
||||
clsx "^1.2.1"
|
||||
copy-text-to-clipboard "^3.0.1"
|
||||
infima "0.2.0-alpha.42"
|
||||
lodash "^4.17.21"
|
||||
nprogress "^0.2.0"
|
||||
postcss "^8.4.14"
|
||||
prism-react-renderer "^1.3.5"
|
||||
prismjs "^1.28.0"
|
||||
react-router-dom "^5.3.3"
|
||||
rtlcss "^3.5.0"
|
||||
tslib "^2.4.0"
|
||||
utility-types "^3.10.0"
|
||||
|
||||
"@docusaurus/theme-common@2.2.0", "@docusaurus/theme-common@^2.0.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-common/-/theme-common-2.2.0.tgz#2303498d80448aafdd588b597ce9d6f4cfa930e4"
|
||||
@@ -1512,6 +1809,27 @@
|
||||
tslib "^2.4.0"
|
||||
utility-types "^3.10.0"
|
||||
|
||||
"@docusaurus/theme-common@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-common/-/theme-common-2.3.1.tgz#82f52d80226efef8c4418c4eacfc5051aa215f7f"
|
||||
integrity sha512-RYmYl2OR2biO+yhmW1aS5FyEvnrItPINa+0U2dMxcHpah8reSCjQ9eJGRmAgkZFchV1+aIQzXOI1K7LCW38O0g==
|
||||
dependencies:
|
||||
"@docusaurus/mdx-loader" "2.3.1"
|
||||
"@docusaurus/module-type-aliases" "2.3.1"
|
||||
"@docusaurus/plugin-content-blog" "2.3.1"
|
||||
"@docusaurus/plugin-content-docs" "2.3.1"
|
||||
"@docusaurus/plugin-content-pages" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@types/history" "^4.7.11"
|
||||
"@types/react" "*"
|
||||
"@types/react-router-config" "*"
|
||||
clsx "^1.2.1"
|
||||
parse-numeric-range "^1.3.0"
|
||||
prism-react-renderer "^1.3.5"
|
||||
tslib "^2.4.0"
|
||||
use-sync-external-store "^1.2.0"
|
||||
utility-types "^3.10.0"
|
||||
|
||||
"@docusaurus/theme-search-algolia@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-search-algolia/-/theme-search-algolia-2.2.0.tgz#77fd9f7a600917e6024fe3ac7fb6cfdf2ce84737"
|
||||
@@ -1534,6 +1852,28 @@
|
||||
tslib "^2.4.0"
|
||||
utility-types "^3.10.0"
|
||||
|
||||
"@docusaurus/theme-search-algolia@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-search-algolia/-/theme-search-algolia-2.3.1.tgz#d587b40913119e9287d14670e277b933d8f453f0"
|
||||
integrity sha512-JdHaRqRuH1X++g5fEMLnq7OtULSGQdrs9AbhcWRQ428ZB8/HOiaN6mj3hzHvcD3DFgu7koIVtWPQnvnN7iwzHA==
|
||||
dependencies:
|
||||
"@docsearch/react" "^3.1.1"
|
||||
"@docusaurus/core" "2.3.1"
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/plugin-content-docs" "2.3.1"
|
||||
"@docusaurus/theme-common" "2.3.1"
|
||||
"@docusaurus/theme-translations" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
"@docusaurus/utils-validation" "2.3.1"
|
||||
algoliasearch "^4.13.1"
|
||||
algoliasearch-helper "^3.10.0"
|
||||
clsx "^1.2.1"
|
||||
eta "^2.0.0"
|
||||
fs-extra "^10.1.0"
|
||||
lodash "^4.17.21"
|
||||
tslib "^2.4.0"
|
||||
utility-types "^3.10.0"
|
||||
|
||||
"@docusaurus/theme-translations@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-translations/-/theme-translations-2.2.0.tgz#5fbd4693679806f80c26eeae1381e1f2c23d83e7"
|
||||
@@ -1542,6 +1882,14 @@
|
||||
fs-extra "^10.1.0"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/theme-translations@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/theme-translations/-/theme-translations-2.3.1.tgz#b2b1ecc00a737881b5bfabc19f90b20f0fe02bb3"
|
||||
integrity sha512-BsBZzAewJabVhoGG1Ij2u4pMS3MPW6gZ6sS4pc+Y7czevRpzxoFNJXRtQDVGe7mOpv/MmRmqg4owDK+lcOTCVQ==
|
||||
dependencies:
|
||||
fs-extra "^10.1.0"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/types@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/types/-/types-2.2.0.tgz#02c577a4041ab7d058a3c214ccb13647e21a9857"
|
||||
@@ -1556,6 +1904,20 @@
|
||||
webpack "^5.73.0"
|
||||
webpack-merge "^5.8.0"
|
||||
|
||||
"@docusaurus/types@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/types/-/types-2.3.1.tgz#785ade2e0f4e35e1eb7fb0d04c27d11c3991a2e8"
|
||||
integrity sha512-PREbIRhTaNNY042qmfSE372Jb7djZt+oVTZkoqHJ8eff8vOIc2zqqDqBVc5BhOfpZGPTrE078yy/torUEZy08A==
|
||||
dependencies:
|
||||
"@types/history" "^4.7.11"
|
||||
"@types/react" "*"
|
||||
commander "^5.1.0"
|
||||
joi "^17.6.0"
|
||||
react-helmet-async "^1.3.0"
|
||||
utility-types "^3.10.0"
|
||||
webpack "^5.73.0"
|
||||
webpack-merge "^5.8.0"
|
||||
|
||||
"@docusaurus/utils-common@2.2.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/utils-common/-/utils-common-2.2.0.tgz#a401c1b93a8697dd566baf6ac64f0fdff1641a78"
|
||||
@@ -1563,6 +1925,13 @@
|
||||
dependencies:
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/utils-common@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/utils-common/-/utils-common-2.3.1.tgz#1abe66846eb641547e4964d44f3011938e58e50b"
|
||||
integrity sha512-pVlRpXkdNcxmKNxAaB1ya2hfCEvVsLDp2joeM6K6uv55Oc5nVIqgyYSgSNKZyMdw66NnvMfsu0RBylcwZQKo9A==
|
||||
dependencies:
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/utils-validation@2.2.0", "@docusaurus/utils-validation@^2.0.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/utils-validation/-/utils-validation-2.2.0.tgz#04d4d103137ad0145883971d3aa497f4a1315f25"
|
||||
@@ -1574,6 +1943,17 @@
|
||||
js-yaml "^4.1.0"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/utils-validation@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/utils-validation/-/utils-validation-2.3.1.tgz#b65c718ba9b84b7a891bccf5ac6d19b57ee7d887"
|
||||
integrity sha512-7n0208IG3k1HVTByMHlZoIDjjOFC8sbViHVXJx0r3Q+3Ezrx+VQ1RZ/zjNn6lT+QBCRCXlnlaoJ8ug4HIVgQ3w==
|
||||
dependencies:
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@docusaurus/utils" "2.3.1"
|
||||
joi "^17.6.0"
|
||||
js-yaml "^4.1.0"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@docusaurus/utils@2.2.0", "@docusaurus/utils@^2.0.0":
|
||||
version "2.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/utils/-/utils-2.2.0.tgz#3d6f9b7a69168d5c92d371bf21c556a4f50d1da6"
|
||||
@@ -1595,6 +1975,28 @@
|
||||
url-loader "^4.1.1"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@docusaurus/utils@2.3.1":
|
||||
version "2.3.1"
|
||||
resolved "https://registry.yarnpkg.com/@docusaurus/utils/-/utils-2.3.1.tgz#24b9cae3a23b1e6dc88f95c45722c7e82727b032"
|
||||
integrity sha512-9WcQROCV0MmrpOQDXDGhtGMd52DHpSFbKLfkyaYumzbTstrbA5pPOtiGtxK1nqUHkiIv8UwexS54p0Vod2I1lg==
|
||||
dependencies:
|
||||
"@docusaurus/logger" "2.3.1"
|
||||
"@svgr/webpack" "^6.2.1"
|
||||
escape-string-regexp "^4.0.0"
|
||||
file-loader "^6.2.0"
|
||||
fs-extra "^10.1.0"
|
||||
github-slugger "^1.4.0"
|
||||
globby "^11.1.0"
|
||||
gray-matter "^4.0.3"
|
||||
js-yaml "^4.1.0"
|
||||
lodash "^4.17.21"
|
||||
micromatch "^4.0.5"
|
||||
resolve-pathname "^3.0.0"
|
||||
shelljs "^0.8.5"
|
||||
tslib "^2.4.0"
|
||||
url-loader "^4.1.1"
|
||||
webpack "^5.73.0"
|
||||
|
||||
"@faker-js/faker@5.5.3":
|
||||
version "5.5.3"
|
||||
resolved "https://registry.yarnpkg.com/@faker-js/faker/-/faker-5.5.3.tgz#18e3af6b8eae7984072bbeb0c0858474d7c4cefe"
|
||||
@@ -3456,6 +3858,11 @@ deep-extend@^0.6.0:
|
||||
resolved "https://registry.yarnpkg.com/deep-extend/-/deep-extend-0.6.0.tgz#c4fa7c95404a17a9c3e8ca7e1537312b736330ac"
|
||||
integrity sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==
|
||||
|
||||
deepmerge@^4.0.0:
|
||||
version "4.3.0"
|
||||
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.3.0.tgz#65491893ec47756d44719ae520e0e2609233b59b"
|
||||
integrity sha512-z2wJZXrmeHdvYJp/Ux55wIjqo81G5Bp4c+oELTW+7ar6SogWHajt5a9gO3s3IDaGSAXjDk0vlQKN3rms8ab3og==
|
||||
|
||||
deepmerge@^4.2.2:
|
||||
version "4.2.2"
|
||||
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955"
|
||||
@@ -3899,6 +4306,11 @@ eta@^1.12.3:
|
||||
resolved "https://registry.yarnpkg.com/eta/-/eta-1.12.3.tgz#2982d08adfbef39f9fa50e2fbd42d7337e7338b1"
|
||||
integrity sha512-qHixwbDLtekO/d51Yr4glcaUJCIjGVJyTzuqV4GPlgZo1YpgOKG+avQynErZIYrfM6JIJdtiG2Kox8tbb+DoGg==
|
||||
|
||||
eta@^2.0.0:
|
||||
version "2.0.0"
|
||||
resolved "https://registry.yarnpkg.com/eta/-/eta-2.0.0.tgz#376865fadebc899e5b6dfce82fae64cbbe47e594"
|
||||
integrity sha512-NqE7S2VmVwgMS8yBxsH4VgNQjNjLq1gfGU0u9I6Cjh468nPRMoDfGdK9n1p/3Dvsw3ebklDkZsFAnKJ9sefjBA==
|
||||
|
||||
etag@~1.8.1:
|
||||
version "1.8.1"
|
||||
resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887"
|
||||
@@ -5254,6 +5666,11 @@ liquid-json@0.3.1:
|
||||
resolved "https://registry.yarnpkg.com/liquid-json/-/liquid-json-0.3.1.tgz#9155a18136d8a6b2615e5f16f9a2448ab6b50eea"
|
||||
integrity sha512-wUayTU8MS827Dam6MxgD72Ui+KOSF+u/eIqpatOtjnvgJ0+mnDq33uC2M7J0tPK+upe/DpUAuK4JUU89iBoNKQ==
|
||||
|
||||
load-script@^1.0.0:
|
||||
version "1.0.0"
|
||||
resolved "https://registry.yarnpkg.com/load-script/-/load-script-1.0.0.tgz#0491939e0bee5643ee494a7e3da3d2bac70c6ca4"
|
||||
integrity sha512-kPEjMFtZvwL9TaZo0uZ2ml+Ye9HUMmPwbYRJ324qF9tqMejwykJ5ggTyvzmrbBeapCAbk98BSbTeovHEEP1uCA==
|
||||
|
||||
loader-runner@^4.2.0:
|
||||
version "4.3.0"
|
||||
resolved "https://registry.yarnpkg.com/loader-runner/-/loader-runner-4.3.0.tgz#c1b4a163b99f614830353b16755e7149ac2314e1"
|
||||
@@ -5470,6 +5887,11 @@ memfs@^3.1.2, memfs@^3.4.3:
|
||||
dependencies:
|
||||
fs-monkey "^1.0.3"
|
||||
|
||||
memoize-one@^5.1.1:
|
||||
version "5.2.1"
|
||||
resolved "https://registry.yarnpkg.com/memoize-one/-/memoize-one-5.2.1.tgz#8337aa3c4335581839ec01c3d594090cebe8f00e"
|
||||
integrity sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q==
|
||||
|
||||
merge-descriptors@1.0.1:
|
||||
version "1.0.1"
|
||||
resolved "https://registry.yarnpkg.com/merge-descriptors/-/merge-descriptors-1.0.1.tgz#b00aaa556dd8b44568150ec9d1b953f3f90cbb61"
|
||||
@@ -6715,7 +7137,7 @@ react-error-overlay@^6.0.11:
|
||||
resolved "https://registry.yarnpkg.com/react-error-overlay/-/react-error-overlay-6.0.11.tgz#92835de5841c5cf08ba00ddd2d677b6d17ff9adb"
|
||||
integrity sha512-/6UZ2qgEyH2aqzYZgQPxEnz33NJ2gNsnHA2o5+o4wW9bLM/JYQitNP9xPhsXwC08hMMovfGe/8retsdDsczPRg==
|
||||
|
||||
react-fast-compare@^3.2.0:
|
||||
react-fast-compare@^3.0.1, react-fast-compare@^3.2.0:
|
||||
version "3.2.0"
|
||||
resolved "https://registry.yarnpkg.com/react-fast-compare/-/react-fast-compare-3.2.0.tgz#641a9da81b6a6320f270e89724fb45a0b39e43bb"
|
||||
integrity sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA==
|
||||
@@ -6768,6 +7190,17 @@ react-magic-dropzone@^1.0.1:
|
||||
resolved "https://registry.yarnpkg.com/react-magic-dropzone/-/react-magic-dropzone-1.0.1.tgz#bfd25b77b57e7a04aaef0a28910563b707ee54df"
|
||||
integrity sha512-0BIROPARmXHpk4AS3eWBOsewxoM5ndk2psYP/JmbCq8tz3uR2LIV1XiroZ9PKrmDRMctpW+TvsBCtWasuS8vFA==
|
||||
|
||||
react-player@^2.11.0:
|
||||
version "2.11.0"
|
||||
resolved "https://registry.yarnpkg.com/react-player/-/react-player-2.11.0.tgz#9afc75314eb915238e8d6615b2891fbe7170aeaa"
|
||||
integrity sha512-fIrwpuXOBXdEg1FiyV9isKevZOaaIsAAtZy5fcjkQK9Nhmk1I2NXzY/hkPos8V0zb/ZX416LFy8gv7l/1k3a5w==
|
||||
dependencies:
|
||||
deepmerge "^4.0.0"
|
||||
load-script "^1.0.0"
|
||||
memoize-one "^5.1.1"
|
||||
prop-types "^15.7.2"
|
||||
react-fast-compare "^3.0.1"
|
||||
|
||||
react-redux@^7.2.0:
|
||||
version "7.2.9"
|
||||
resolved "https://registry.yarnpkg.com/react-redux/-/react-redux-7.2.9.tgz#09488fbb9416a4efe3735b7235055442b042481d"
|
||||
@@ -8091,6 +8524,11 @@ use-latest@^1.2.1:
|
||||
dependencies:
|
||||
use-isomorphic-layout-effect "^1.1.1"
|
||||
|
||||
use-sync-external-store@^1.2.0:
|
||||
version "1.2.0"
|
||||
resolved "https://registry.yarnpkg.com/use-sync-external-store/-/use-sync-external-store-1.2.0.tgz#7dbefd6ef3fe4e767a0cf5d7287aacfb5846928a"
|
||||
integrity sha512-eEgnFxGQ1Ife9bzYs6VLi8/4X6CObHMw9Qr9tPY43iKwsPw8xE8+EFsf/2cFZ5S3esXgpWgtSCtLNS41F+sKPA==
|
||||
|
||||
util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1:
|
||||
version "1.0.2"
|
||||
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
|
||||
|
||||
+54
-4
@@ -1,14 +1,64 @@
|
||||
# OpenAssitant Inference
|
||||
# OpenAssistant Inference
|
||||
|
||||
Preliminary implementation of the inference engine for OpenAssistant.
|
||||
|
||||
## Development Variant 1 (you'll need tmux)
|
||||
## Development Variant 1 (docker compose)
|
||||
|
||||
The services of the inference stack are prefixed with "inference-" in the
|
||||
[unified compose descriptor](../docker-compose.yaml). <br/> Prior to building
|
||||
those, please ensure that you have Docker's new
|
||||
[BuildKit](https://docs.docker.com/build/buildkit/) backend enabled. See the
|
||||
[FAQ](https://projects.laion.ai/Open-Assistant/docs/faq#enable-dockers-buildkit-backend)
|
||||
for more info.
|
||||
|
||||
To build the services, run:
|
||||
|
||||
```shell
|
||||
docker compose --profile inference build
|
||||
```
|
||||
|
||||
Spin up the stack:
|
||||
|
||||
```shell
|
||||
docker compose --profile inference up -d
|
||||
```
|
||||
|
||||
Tail the logs:
|
||||
|
||||
```shell
|
||||
docker compose logs -f \
|
||||
inference-server \
|
||||
inference-worker \
|
||||
inference-text-client \
|
||||
inference-text-generation-server
|
||||
```
|
||||
|
||||
Attach to the text-client, and start chatting:
|
||||
|
||||
```shell
|
||||
docker attach open-assistant-inference-text-client-1
|
||||
```
|
||||
|
||||
> **Note:** In the last step, `open-assistant-inference-text-client-1` refers to
|
||||
> the name of the `text-client` container started in step 2.
|
||||
|
||||
> **Note:** The compose file contains the bind mounts enabling you to develop on
|
||||
> the modules of the inference stack, and the `oasst-shared` package, without
|
||||
> rebuilding.
|
||||
|
||||
> **Note:** You can spin up any number of workers by adjusting the number of
|
||||
> replicas of the `inference-worker` service to your liking.
|
||||
|
||||
> **Note:** Please wait for the `inference-text-generation-server` service to
|
||||
> output `{"message":"Connected"}` before starting to chat.
|
||||
|
||||
## Development Variant 2 (you'll need tmux)
|
||||
|
||||
Run `./full-dev-setup.sh` to start the full development setup. Make sure to wait
|
||||
until the 2nd terminal is ready and says `{"message":"Connected"}` before
|
||||
entering input into the last terminal.
|
||||
|
||||
## Development Variant 2 (you'll need multiple terminals)
|
||||
## Development Variant 3 (you'll need multiple terminals)
|
||||
|
||||
Run a redis container (or use the one of the general docker compose file):
|
||||
|
||||
@@ -36,7 +86,7 @@ For the worker, you'll also want to have the text-generation-inference server
|
||||
running:
|
||||
|
||||
```bash
|
||||
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
|
||||
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference
|
||||
```
|
||||
|
||||
Run the client:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
tmux new-session -d -s "inference-dev-setup"
|
||||
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd server" C-m
|
||||
tmux send-keys "uvicorn main:app --reload" C-m
|
||||
|
||||
+37
-14
@@ -39,14 +39,6 @@ redisClient = redis.Redis(
|
||||
)
|
||||
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class CreateChatResponse(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class MessageRequest(pydantic.BaseModel):
|
||||
message: str = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
@@ -57,7 +49,7 @@ class MessageRequest(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TokenResponseEvent(pydantic.BaseModel):
|
||||
token: str
|
||||
token: inference.TokenResponse
|
||||
|
||||
|
||||
class MessageRequestState(str, enum.Enum):
|
||||
@@ -67,30 +59,61 @@ class MessageRequestState(str, enum.Enum):
|
||||
aborted_by_worker = "aborted_by_worker"
|
||||
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class ChatListEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class ChatEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
conversation: protocol.Conversation
|
||||
|
||||
|
||||
class ListChatsResponse(pydantic.BaseModel):
|
||||
chats: list[ChatListEntry]
|
||||
|
||||
|
||||
class DbChatEntry(pydantic.BaseModel):
|
||||
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
|
||||
pending_message_request: MessageRequest | None = None
|
||||
message_request_state: MessageRequestState | None = None
|
||||
|
||||
def to_list_entry(self) -> ChatListEntry:
|
||||
return ChatListEntry(id=self.id)
|
||||
|
||||
def to_entry(self) -> ChatEntry:
|
||||
return ChatEntry(id=self.id, conversation=self.conversation)
|
||||
|
||||
|
||||
# TODO: make real database
|
||||
CHATS: dict[str, DbChatEntry] = {}
|
||||
|
||||
|
||||
@app.get("/chat")
|
||||
async def list_chats() -> ListChatsResponse:
|
||||
"""Lists all chats."""
|
||||
logger.info("Listing all chats.")
|
||||
chats = [chat.to_list_entry() for chat in CHATS.values()]
|
||||
return ListChatsResponse(chats=chats)
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
|
||||
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
|
||||
"""Allows a client to create a new chat."""
|
||||
logger.info(f"Received {request}")
|
||||
chat = DbChatEntry()
|
||||
CHATS[chat.id] = chat
|
||||
return CreateChatResponse(id=chat.id)
|
||||
return ChatListEntry(id=chat.id)
|
||||
|
||||
|
||||
@app.get("/chat/{id}")
|
||||
async def get_chat(id: str) -> protocol.Conversation:
|
||||
async def get_chat(id: str) -> ChatEntry:
|
||||
"""Allows a client to get the current state of a chat."""
|
||||
return CHATS[id].conversation
|
||||
return CHATS[id].to_entry()
|
||||
|
||||
|
||||
@app.post("/chat/{id}/message")
|
||||
@@ -143,7 +166,7 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text="".join([d.token for d in result_data[:-1]]),
|
||||
text=response_packet.generated_text.text,
|
||||
is_assistant=True,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
|
||||
print("Assistant: ", end="", flush=True)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
print(data["token"], end="", flush=True)
|
||||
print(data["token"]["text"], end="", flush=True)
|
||||
print()
|
||||
|
||||
|
||||
|
||||
@@ -54,24 +54,49 @@ def main(
|
||||
"top_p": work_request.top_p,
|
||||
"temperature": work_request.temperature,
|
||||
"seed": work_request.seed,
|
||||
# "stop": ["\nUser:", "\nAssistant:"], # TODO: make this a bit more workable because it's mutliple tokens
|
||||
},
|
||||
},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError:
|
||||
logger.exception("Failed to get response from inference server")
|
||||
logger.error(f"Response: {response.text}")
|
||||
return
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
for event in client.events():
|
||||
logger.debug(f"Received event: {event}")
|
||||
data = json.loads(event.data)
|
||||
if data["is_end"]:
|
||||
if data["generated_text"]:
|
||||
break
|
||||
intermediate = data["event"]
|
||||
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
|
||||
ws.send(inference.WorkResponsePacket(is_end=True).json())
|
||||
token = data["token"]
|
||||
ws.send(
|
||||
inference.WorkResponsePacket(
|
||||
token=inference.TokenResponse(
|
||||
text=token["text"],
|
||||
log_prob=token["logprob"],
|
||||
token_id=token["id"],
|
||||
)
|
||||
).json()
|
||||
)
|
||||
ws.send(
|
||||
inference.WorkResponsePacket(
|
||||
is_end=True,
|
||||
generated_text=inference.GeneratedTextResponse(
|
||||
text=data["generated_text"],
|
||||
),
|
||||
).json()
|
||||
)
|
||||
|
||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||
logger.error(f"Connection error: {error}")
|
||||
try:
|
||||
raise error
|
||||
except Exception:
|
||||
logger.exception("Error in websocket")
|
||||
|
||||
def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
|
||||
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
Some other reward features we can use
|
||||
|
||||
0. Finish classifcation feature
|
||||
0. Finish classification feature
|
||||
|
||||
1. Summaries from human feedback
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
[] support additional negative samples generated from other models.
|
||||
|
||||
For example we can use galactica-125m to generate a TLDR and assume it was
|
||||
inferior than the human perference one
|
||||
inferior than the human preference one
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -46,6 +46,7 @@ defaults:
|
||||
quantization: false
|
||||
seq2seqmodel: false
|
||||
poly_eps: 1.0
|
||||
fuse_gelu: true
|
||||
|
||||
galactica-125m:
|
||||
learning_rate: 5e-5
|
||||
|
||||
@@ -23,5 +23,5 @@ Issues and TODO:
|
||||
- ideally we can update the config yaml and new dataset will be download from
|
||||
hub
|
||||
|
||||
- one possible idea is we upload the trasform format of these dataset to the
|
||||
- one possible idea is we upload the transform format of these dataset to the
|
||||
OA hub
|
||||
|
||||
@@ -311,7 +311,7 @@ class JokeExplaination(Dataset):
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
joke = data["joke"]
|
||||
explanation = data["explaination"]
|
||||
explanation = data["explanation"]
|
||||
self.pairs.append((joke, explanation))
|
||||
|
||||
if len(question) > 0 and len(answer) > 0:
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from transformers.activations import FastGELUActivation, GELUActivation, NewGELUActivation, QuickGELUActivation
|
||||
|
||||
|
||||
def rsetattr(obj, attr, val):
|
||||
pre, _, post = attr.rpartition(".")
|
||||
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
|
||||
|
||||
|
||||
def rgetattr(obj, attr, *args):
|
||||
def _getattr(obj, attr):
|
||||
return getattr(obj, attr, *args)
|
||||
|
||||
return functools.reduce(_getattr, [obj] + attr.split("."))
|
||||
|
||||
|
||||
def fuse_gelu(model):
|
||||
@torch.jit.script
|
||||
def gelu_fwd(x):
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
@torch.jit.script
|
||||
def gelu_bwd(g, x):
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff * g
|
||||
|
||||
class _FusedGeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input):
|
||||
ctx.input_tensor = input
|
||||
return gelu_fwd(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input = ctx.input_tensor
|
||||
tmp = gelu_bwd(grad_output, input)
|
||||
return tmp
|
||||
|
||||
class FusedGelu(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return _FusedGeLUFunction.apply(input)
|
||||
|
||||
fused_gelu_module = FusedGelu()
|
||||
hf_gelu_functions = [GELUActivation, FastGELUActivation, NewGELUActivation, QuickGELUActivation]
|
||||
|
||||
for name, module in model.named_modules():
|
||||
for hf_gelu_function in hf_gelu_functions:
|
||||
if isinstance(module, hf_gelu_function):
|
||||
rsetattr(model, name, fused_gelu_module)
|
||||
|
||||
return model
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
import torch
|
||||
from efficiency_utils import fuse_gelu
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
@@ -180,6 +181,9 @@ if __name__ == "__main__":
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
|
||||
if training_conf.fuse_gelu:
|
||||
model = fuse_gelu(model)
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
num_train_epochs=training_conf.num_train_epochs,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, NamedTuple
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
@@ -66,19 +66,64 @@ class PerDatasetSampler(Sampler):
|
||||
return cls(dataset_sizes, dataset_size_per_epoch)
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
def get_dataset_fractions(conf, dataset_sizes):
|
||||
"""Calculate fraction of each dataset to use per epoch when subsampling"""
|
||||
fractions = []
|
||||
for i, data_config in enumerate(conf):
|
||||
dataset_name = get_dataset_name_from_data_config(data_config)
|
||||
if isinstance(data_config, dict):
|
||||
if "fraction" in data_config[dataset_name]:
|
||||
if data_config[dataset_name]["fraction"] <= 0:
|
||||
raise ValueError("Please specify fraction as a value between 0 < fraction <= 1")
|
||||
fractions.append(min(1, data_config[dataset_name]["fraction"]))
|
||||
elif "size" in data_config[dataset_name]:
|
||||
if data_config[dataset_name]["size"] > dataset_sizes[i]:
|
||||
raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
|
||||
fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i])
|
||||
else:
|
||||
raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
|
||||
else:
|
||||
fractions.append(1)
|
||||
return fractions
|
||||
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
elif "GPT-JT" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
|
||||
elif "codegen" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
elif "pythia" in conf.model_name:
|
||||
tokenizer.add_special_tokens(
|
||||
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
|
||||
)
|
||||
|
||||
class SpecialTokens(NamedTuple):
|
||||
pad_token: str = ""
|
||||
eos_token: str = ""
|
||||
sep_token: str = ""
|
||||
|
||||
|
||||
class TokenizerConfig(NamedTuple):
|
||||
special_tokens: SpecialTokens = {}
|
||||
|
||||
|
||||
TOKENIZER_CONFIGS = {
|
||||
"galactica": TokenizerConfig(special_tokens=SpecialTokens("<pad>", "</s>")),
|
||||
"GPT-JT": TokenizerConfig(special_tokens=SpecialTokens(sep_token="<|extratoken_100|>")),
|
||||
"codegen": TokenizerConfig(special_tokens=SpecialTokens("<|endoftext|>", sep_token="<|endoftext|>")),
|
||||
"pythia": TokenizerConfig(special_tokens=SpecialTokens("<|padding|>", "<|endoftext|>", "<|endoftext|>")),
|
||||
}
|
||||
|
||||
|
||||
def match_tokenizer_name(model_name: str) -> TokenizerConfig:
|
||||
"""Match a partial model name to a tokenizer configuration"""
|
||||
tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name]
|
||||
if not tokenizer_config_matches:
|
||||
raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}")
|
||||
elif 1 < len(tokenizer_config_matches):
|
||||
raise ValueError(f"Found multiple tokeniser configuration matches for {model_name=}")
|
||||
else:
|
||||
return tokenizer_config_matches[0]
|
||||
|
||||
|
||||
def get_tokenizer(conf) -> transformers.AutoTokenizer:
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
tokenizer_config = match_tokenizer_name(conf.model_name)
|
||||
|
||||
if tokenizer_config.special_tokens:
|
||||
if "GPT-JT" in conf.model_name:
|
||||
tokenizer_config.special_tokens.pad_token = tokenizer.eos_token
|
||||
tokenizer.add_special_tokens(tokenizer_config.special_tokens)
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
@@ -171,27 +216,6 @@ def get_dataset_name_from_data_config(data_config):
|
||||
return data_config
|
||||
|
||||
|
||||
def get_dataset_fractions(conf, dataset_sizes):
|
||||
"""Calculate fraction of each dataset to use per epoch when subsampling"""
|
||||
fractions = []
|
||||
for i, data_config in enumerate(conf):
|
||||
dataset_name = get_dataset_name_from_data_config(data_config)
|
||||
if isinstance(data_config, dict):
|
||||
if "fraction" in data_config[dataset_name]:
|
||||
if data_config[dataset_name]["fraction"] <= 0:
|
||||
raise ValueError("Please specify fraction as a value between 0 < fraction <= 1")
|
||||
fractions.append(min(1, data_config[dataset_name]["fraction"]))
|
||||
elif "size" in data_config[dataset_name]:
|
||||
if data_config[dataset_name]["size"] > dataset_sizes[i]:
|
||||
raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
|
||||
fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i])
|
||||
else:
|
||||
raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
|
||||
else:
|
||||
fractions.append(1)
|
||||
return fractions
|
||||
|
||||
|
||||
def get_dataset(conf, tokenizer):
|
||||
train_datasets, evals = [], {}
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ with (open("paragraphs.pkl", "rb")) as openfile:
|
||||
try:
|
||||
objects.append(pickle.load(openfile))
|
||||
except EOFError:
|
||||
print("Problem laoding your pickle file, using the default array")
|
||||
print("Problem loading your pickle file, using the default array")
|
||||
pickle_fail = True
|
||||
break
|
||||
|
||||
@@ -92,7 +92,7 @@ if pickle_fail:
|
||||
paragraphs = [
|
||||
"Like for me, this thing is like a little side hobby, but it's also one that's deeply fulfilling. So not just from a business perspective, which is not the way I think about it. I just think from a life human perspective, it's I probably wouldn't have this kind of conversation with you off mic, like this long, this deep, this attentive. There's something really fulfilling about these conversations. So what advice would you have for me? What advice do you have for yourself? Oh, have you not introspected this that deeply? Oh, I have advice. I think the first advice I would give to you is I think you should have me on more often. Yeah. Yeah. That's first and foremost. And second is go on your podcast and have a conversation. Well, I would say you come on my podcast when you're ready. Yeah. When you feel like the product that I'm putting out would benefit from your presence and vice versa, not as a favor to a bro, but at the right time.",
|
||||
"Well, we really are looking through a two dimensional screen until it's what we intuit to be a three dimensional world and also inferring dynamic stuff, making it 4D. Anyway, is it possible to visualize some pretty pictures that give us a deeper sense of the truth of reality? I think that we will incrementally be able to do that. I think that, for example, the picture that we have of electrons and photons interacting and scattering, it may have not been possible until Faraday did all of his experiments and then Maxwell wrote down his equations. And we were then sort of forced by his equations to think in a new way. And then when Planck in 1900, desperate to try to solve the problem of black body radiation, what they call the ultraviolet catastrophe where Newton was predicting infinite energies where there weren't infinite energies in black body radiation. And he in desperation proposed packets of energy. Then once you've done that, and then you have an Einstein come along five years later and show how that explains the photoelectric effect.",
|
||||
"But man, I don't know how I would feel about just bacteria everywhere. Well, it would be depressing if it was true. I suppose depressing, I don't think, I don't. I don't know what's more depressing, bacteria everywhere or nothing everywhere. Yes, either of them are chilling. Yeah. But whether it's chilling or not, I don't think should force us to change our view about whether it's real or not. Yes. And what I'm saying may or may not be true. So how would you feel if we discovered life on Mars? Absolutely. It sounds like you would be less excited than some others because you're like, well. What I would be most interested in is how similar to life on Earth it would be. It would actually turn into quite a subtle problem because the likelihood of life having gone to and fro between Mars and the Earth is quite, I wouldn't say high, but it's not low, it's quite feasible. And so if we found life on Mars and it had very similar genetic code, but it was slightly different, most people would interpret that immediately as evidence that they've been transit one way or the other and that it was a common origin of life on Mars or on the Earth and it went one way or the other way.",
|
||||
"But man, I don't know how I would feel about just bacteria everywhere. Well, it would be depressing if it was true. I suppose depressing, I don't think, I don't. I don't know what's more depressing, bacteria everywhere or nothing everywhere. Yes, either of them are chilling. Yeah. But whether it's chilling or not, I don't think should force us to change our view about whether it's real or not. Yes. And what I'm saying may or may not be true. So how would you feel if we discovered life on Mars? Absolutely. It sounds like you would be less excited than some others because you're like, well. What I would be most interested in is how similar to life on Earth it would be. It would actually turn into quite a subtle problem because the likelihood of life having gone to and from between Mars and the Earth is quite, I wouldn't say high, but it's not low, it's quite feasible. And so if we found life on Mars and it had very similar genetic code, but it was slightly different, most people would interpret that immediately as evidence that they've been transit one way or the other and that it was a common origin of life on Mars or on the Earth and it went one way or the other way.",
|
||||
]
|
||||
|
||||
# Make sure no paragraphs are too long for T5. It handles up to 512 tokens context length.
|
||||
|
||||
@@ -86,7 +86,7 @@ Each question and all related answers are on a single line in JSONL format.
|
||||
#### Table/CSV/Parquet Format
|
||||
|
||||
There are a lot more columns left over in the table format. `_q` and `_a` are
|
||||
suffixes indiciating if the column came from the question or answer table as
|
||||
suffixes indicating if the column came from the question or answer table as
|
||||
leftover from a join statement.
|
||||
|
||||
```
|
||||
|
||||
@@ -15,7 +15,7 @@ trained on
|
||||
| multilingual | xlm-roberta-base | Multilingual Toxic Comment Classification |
|
||||
|
||||
Unbiased and original models also have a 'small' version - but since normal
|
||||
models are not memory heavy, and small models perform noticably worse, they are
|
||||
models are not memory heavy, and small models perform noticeably worse, they are
|
||||
only described in the notebook
|
||||
|
||||
## All tests below were ran on a 3090TI
|
||||
|
||||
@@ -7,7 +7,7 @@ this project. Please try and follow this structure as closely as possible. While
|
||||
things will not exactly be the same for each notebook some principles we would
|
||||
like to try ensure are:
|
||||
|
||||
1. Each notebook or collection of related or dependant notebooks should live in
|
||||
1. Each notebook or collection of related or dependent notebooks should live in
|
||||
its own folder.
|
||||
1. Each notebook should have a markdown file with the same name as the notebook
|
||||
(or README.md if it's a single notebook folder) that explains what the
|
||||
|
||||
@@ -68,12 +68,15 @@ class OasstApiClient:
|
||||
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key})
|
||||
logger.debug(f"response: {response}")
|
||||
|
||||
# If the response is not a 2XX, check to see
|
||||
# if the json has the fields to create an
|
||||
# OasstError.
|
||||
if response.status >= 300:
|
||||
text = await response.text()
|
||||
logger.debug(f"resp text: {text}")
|
||||
data = await response.json()
|
||||
try:
|
||||
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
|
||||
@@ -114,20 +117,21 @@ class OasstApiClient:
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a task from the backend."""
|
||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
|
||||
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
||||
logger.debug(f"RESP {resp}")
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a random task from the backend."""
|
||||
logger.debug(f"Fetching random for user {user}")
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)
|
||||
|
||||
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
|
||||
"""Send an ACK for a task to the backend."""
|
||||
|
||||
@@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum):
|
||||
SERVER_ERROR0 = 500
|
||||
SERVER_ERROR1 = 501
|
||||
|
||||
INVALID_AUTHENTICATION = 600
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
@@ -80,6 +82,7 @@ class OasstErrorCode(IntEnum):
|
||||
USER_NOT_SPECIFIED = 4000
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
USER_HAS_NOT_ACCEPTED_TOS = 4003
|
||||
|
||||
EMOJI_OP_UNSUPPORTED = 5000
|
||||
|
||||
@@ -92,7 +95,7 @@ class OasstError(Exception):
|
||||
http_status_code: HTTPStatus
|
||||
|
||||
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
||||
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
|
||||
super().__init__(message, error_code, http_status_code) # make exception picklable (fill args member)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
self.http_status_code = http_status_code
|
||||
|
||||
@@ -13,13 +13,24 @@ class WorkRequest(pydantic.BaseModel):
|
||||
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1))
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**31 - 1))
|
||||
do_sample: bool = True
|
||||
top_k: int = 50
|
||||
top_p: float = 0.9
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
class TokenResponse(pydantic.BaseModel):
|
||||
text: str
|
||||
log_prob: float
|
||||
token_id: int
|
||||
|
||||
|
||||
class GeneratedTextResponse(pydantic.BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class WorkResponsePacket(pydantic.BaseModel):
|
||||
token: str | None = None
|
||||
token: TokenResponse | None = None
|
||||
generated_text: GeneratedTextResponse | None = None
|
||||
is_end: bool = False
|
||||
|
||||
@@ -29,6 +29,17 @@ class User(BaseModel):
|
||||
auth_method: Literal["discord", "local", "system"]
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
id: UUID
|
||||
provider: str
|
||||
provider_account_id: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class FrontEndUser(User):
|
||||
user_id: UUID
|
||||
enabled: bool
|
||||
@@ -39,6 +50,7 @@ class FrontEndUser(User):
|
||||
streak_days: Optional[int] = None
|
||||
streak_last_day_date: Optional[datetime] = None
|
||||
last_activity_date: Optional[datetime] = None
|
||||
tos_acceptance_date: Optional[datetime] = None
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
@@ -64,6 +76,7 @@ class ConversationMessage(BaseModel):
|
||||
is_assistant: bool
|
||||
emojis: Optional[dict[str, int]] = None
|
||||
user_emojis: Optional[list[str]] = None
|
||||
user_is_author: Optional[bool] = None
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
@@ -89,6 +102,12 @@ class Message(ConversationMessage):
|
||||
created_date: Optional[datetime]
|
||||
review_result: Optional[bool]
|
||||
review_count: Optional[int]
|
||||
deleted: Optional[bool]
|
||||
synthetic: Optional[bool]
|
||||
model_name: Optional[str]
|
||||
message_tree_id: Optional[UUID]
|
||||
ranking_count: Optional[int]
|
||||
rank: Optional[int]
|
||||
|
||||
|
||||
class MessagePage(PageResult):
|
||||
@@ -252,22 +271,21 @@ class AbstractLabelTask(Task):
|
||||
mode: Optional[LabelTaskMode]
|
||||
disposition: Optional[LabelTaskDisposition]
|
||||
labels: Optional[list[LabelDescription]]
|
||||
conversation: Conversation # the conversation so far (labeling -> last message)
|
||||
|
||||
|
||||
class LabelInitialPromptTask(AbstractLabelTask):
|
||||
"""A task to label an initial prompt."""
|
||||
|
||||
type: Literal["label_initial_prompt"] = "label_initial_prompt"
|
||||
prompt: str
|
||||
prompt: str | None = Field(None, deprecated=True, description="deprecated, use `prompt_message`")
|
||||
|
||||
|
||||
class LabelConversationReplyTask(AbstractLabelTask):
|
||||
"""A task to label a reply to a conversation."""
|
||||
|
||||
type: Literal["label_conversation_reply"] = "label_conversation_reply"
|
||||
conversation: Conversation # the conversation so far (new: including the reply message)
|
||||
reply_message: Optional[ConversationMessage]
|
||||
reply: str
|
||||
reply: str | None = Field(None, deprecated=True, description="deprecated, use last message of `conversation`")
|
||||
|
||||
|
||||
class LabelPrompterReplyTask(LabelConversationReplyTask):
|
||||
@@ -469,6 +487,47 @@ class LeaderboardStats(BaseModel):
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
class TrollScore(BaseModel):
|
||||
rank: Optional[int]
|
||||
user_id: UUID
|
||||
highlighted: bool = False
|
||||
username: str
|
||||
auth_method: str
|
||||
display_name: str
|
||||
last_activity_date: Optional[datetime]
|
||||
|
||||
troll_score: int = 0
|
||||
|
||||
base_date: Optional[datetime]
|
||||
modified_date: Optional[datetime]
|
||||
|
||||
red_flags: int = 0 # num reported messages of user
|
||||
upvotes: int = 0 # num up-voted messages of user
|
||||
downvotes: int = 0 # num down-voted messages of user
|
||||
|
||||
spam_prompts: int = 0
|
||||
|
||||
quality: Optional[float] = None
|
||||
humor: Optional[float] = None
|
||||
toxicity: Optional[float] = None
|
||||
violence: Optional[float] = None
|
||||
helpfulness: Optional[float] = None
|
||||
|
||||
spam: int = 0
|
||||
lang_mismach: int = 0
|
||||
not_appropriate: int = 0
|
||||
pii: int = 0
|
||||
hate_speech: int = 0
|
||||
sexual_content: int = 0
|
||||
political_content: int = 0
|
||||
|
||||
|
||||
class TrollboardStats(BaseModel):
|
||||
time_frame: str
|
||||
last_updated: datetime
|
||||
trollboard: List[TrollScore]
|
||||
|
||||
|
||||
class OasstErrorResponse(BaseModel):
|
||||
"""The format of an error response from the OASST API."""
|
||||
|
||||
@@ -489,6 +548,11 @@ class EmojiCode(str, enum.Enum):
|
||||
poop = "poop" # 💩
|
||||
skull = "skull" # 💀
|
||||
|
||||
# skip task system uses special emoji codes
|
||||
skip_reply = "_skip_reply"
|
||||
skip_ranking = "_skip_ranking"
|
||||
skip_labeling = "_skip_labeling"
|
||||
|
||||
|
||||
class EmojiOp(str, enum.Enum):
|
||||
togggle = "toggle"
|
||||
@@ -500,3 +564,10 @@ class MessageEmojiRequest(BaseModel):
|
||||
user: User
|
||||
op: EmojiOp = EmojiOp.togggle
|
||||
emoji: EmojiCode
|
||||
|
||||
|
||||
class CreateFrontendUserRequest(User):
|
||||
show_on_leaderboard: bool = True
|
||||
enabled: bool = True
|
||||
tos_acceptance: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
# **Datasets**
|
||||
|
||||
This folder contains datasets loading scripts that are used to train
|
||||
OpenAssistant. The current list of datasets can be found
|
||||
[here](https://docs.google.com/spreadsheets/d/1NYYa6vHiRnk5kwnyYaCT0cBO62--Tm3w4ihdBtp4ISk).
|
||||
|
||||
## **Adding a New Dataset**
|
||||
|
||||
To add a new dataset to OpenAssistant, follow these steps:
|
||||
|
||||
1. **Create an issue**: Create a new
|
||||
[issue](https://github.com/LAION-AI/Open-Assistant/issues/new) and describe
|
||||
your proposal for the new dataset.
|
||||
|
||||
2. **Create a dataset on HuggingFace**: Create a dataset on
|
||||
[HuggingFace](https://huggingface.co). See
|
||||
[below](#creating-a-dataset-on-huggingface) for more details.
|
||||
|
||||
3. **Make a pull request**: Add a new dataset loading script to this folder and
|
||||
link the issue in the pull request description. For more information, see
|
||||
[below](#making-a-pull-request).
|
||||
|
||||
## **Creating a Dataset on HuggingFace**
|
||||
|
||||
To create a new dataset on HuggingFace, follow these steps:
|
||||
|
||||
#### 1. Convert your dataset file(s) to the Parquet format using the [pandas](https://pandas.pydata.org/) library:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Create a pandas dataframe from your dataset file(s)
|
||||
df = pd.read_json(...) # or any other way
|
||||
|
||||
# Save the file in the Parquet format
|
||||
df.to_parquet("dataset.parquet", row_group_size=100, engine="pyarrow")
|
||||
```
|
||||
|
||||
#### 2. Install HuggingFace CLI
|
||||
|
||||
```bash
|
||||
pip install huggingface-cli
|
||||
```
|
||||
|
||||
#### 3. Log in to HuggingFace
|
||||
|
||||
Use your [access token](https://huggingface.co/docs/hub/security-tokens) to
|
||||
login:
|
||||
|
||||
- Via terminal
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
- in Jupyter notebook
|
||||
|
||||
```python
|
||||
from huggingface_hub import notebook_login
|
||||
notebook_login()
|
||||
```
|
||||
|
||||
#### 4. Push the Parquet file to HuggingFace using the following code:
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
ds = Dataset.from_parquet("dataset.parquet")
|
||||
ds.push_to_hub("your_huggingface_name/dataset_name")
|
||||
```
|
||||
|
||||
#### 5. Update the `README.md` file
|
||||
|
||||
Update the `README.md` file of your dataset by visiting this link:
|
||||
https://huggingface.co/datasets/your_huggingface_name/dataset_name/edit/main/README.md
|
||||
(paste your HuggingFace name and dataset)
|
||||
|
||||
## **Making a Pull Request**
|
||||
|
||||
#### 1. Fork this repository
|
||||
|
||||
#### 2. Create a new branch in your fork
|
||||
|
||||
#### 3. Add your dataset to the repository
|
||||
|
||||
- Create a folder with the name of your dataset.
|
||||
- Add a loading script that loads your dataset from HuggingFace, for example:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds = load_dataset("your_huggingface_name/dataset_name")
|
||||
print(ds["train"])
|
||||
```
|
||||
|
||||
- Optionally, add any other files that describe your dataset and its creation,
|
||||
such as a README, notebooks, scrapers, etc.
|
||||
|
||||
#### 4. Stage your changes and run the pre-commit hook
|
||||
|
||||
```bash
|
||||
pre-commit run
|
||||
```
|
||||
|
||||
#### 5. Submit a pull request
|
||||
|
||||
- Submit a pull request and include a link to the issue it resolves in the
|
||||
description, for example: `Resolves #123`
|
||||
+18
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
export DEBUG_ALLOW_SELF_RANKING=True
|
||||
export DEBUG_ALLOW_DUPLICATE_TASKS=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
export RATE_LIMIT=0
|
||||
export DEBUG_USE_SEED_DATA_PATH=$parent_path/../../backend/test_data/generic/test_generic_data.json
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
popd
|
||||
@@ -7,6 +7,7 @@ pushd "$parent_path/../../backend"
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
export DEBUG_ALLOW_SELF_RANKING=True
|
||||
export DEBUG_ALLOW_DUPLICATE_TASKS=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ conversation, or at least as a prompt with replies.
|
||||
guarantee of the quality of the tweets.
|
||||
- The tweet quality is the other major issue. We can get conversations through
|
||||
the currently made scripts, but they most likely don't match a useful
|
||||
instruction -> fulfilment. We are trying to filter the tweets through various
|
||||
instruction -> fulfillment. We are trying to filter the tweets through various
|
||||
means such as matching useful hashtags, or by using cosine similarity against
|
||||
known instructions.
|
||||
- The modern Twitter API has conversation_id as a field which can be a way to
|
||||
@@ -68,7 +68,7 @@ conversation, or at least as a prompt with replies.
|
||||
## TODO
|
||||
|
||||
- Write scripts to filter existing conversations into useful instructions ->
|
||||
fulfilment with hashtags or cosine similarity.
|
||||
fulfillment with hashtags or cosine similarity.
|
||||
- Train model to detect if text is a suitable instruction. This could then be
|
||||
run through the conversations (or full tweet dump) to simplify the process.
|
||||
Related to issue #143.
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
# This assumes data downloaded from https://archive.org/details/twitterstream
|
||||
# and that the internal .tar files are extracted locally.
|
||||
# They are large files so using something like 7Zip or WinRar migth be easier
|
||||
# They are large files so using something like 7Zip or WinRar might be easier
|
||||
# than putting all of it in scripts, but it is a possibility.
|
||||
|
||||
# I often work in notebooks. If you encounter any issue, please reach out to let me know.
|
||||
|
||||
@@ -94,7 +94,7 @@ class EssayReviser(DataAugmenter):
|
||||
def parse_single(self, essay):
|
||||
instructions = []
|
||||
|
||||
# Make stucture error (shuffle one paragraph with another)
|
||||
# Make structure error (shuffle one paragraph with another)
|
||||
essay_paragraphs = essay.split("\n\n") # Splitting a String by newline character (\n)
|
||||
|
||||
rand1 = random.randint(0, len(essay_paragraphs) - 1)
|
||||
@@ -424,7 +424,7 @@ class CodeInstructor(DataAugmenter):
|
||||
|
||||
|
||||
def recognize_entities(text, model, n=4, person="ignore"):
|
||||
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
|
||||
"""Given a text and a model for entity recognition, return the most occurring entities in the text as a string"""
|
||||
doc = model(text)
|
||||
if person == "ignore":
|
||||
ents = Counter([ent.text.strip() for ent in list(doc.ents) if len(ent.text.strip()) >= 5])
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from collections import defaultdict
|
||||
from glob import glob
|
||||
from json import load
|
||||
from os import path
|
||||
|
||||
ALL_PATH = "../../website/public/locales/**/*.json"
|
||||
DIR = path.dirname(__file__)
|
||||
EN_PATH = "../../website/public/locales/en/*.json"
|
||||
|
||||
|
||||
def get_not_translated(en_json, translation_json, parent_key=None):
|
||||
not_translated = []
|
||||
for key in en_json.keys():
|
||||
if key in translation_json and translation_json[key] == en_json[key]:
|
||||
not_translated.append(("{0}.{1}".format(parent_key, key) if parent_key else key))
|
||||
elif isinstance(en_json[key], dict):
|
||||
not_translated.extend(get_not_translated(en_json[key], translation_json[key], key))
|
||||
return not_translated
|
||||
|
||||
|
||||
def get_missing(en_json, translation_json):
|
||||
return [key for key in en_json.keys() if key not in translation_json]
|
||||
|
||||
|
||||
def print_result(missing, not_translated, file):
|
||||
if len(missing):
|
||||
print("[{0}] - {1}\tmissing: {2}".format(path.basename(path.dirname(file)), path.basename(file), missing))
|
||||
if len(not_translated):
|
||||
print(
|
||||
"[{0}] - {1}\tpotentially untranslated: {2}".format(
|
||||
path.basename(path.dirname(file)), path.basename(file), not_translated
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def audit(file, en_file):
|
||||
en_json = load(open(en_file))
|
||||
translation_json = load(open(file))
|
||||
return (get_missing(en_json, translation_json), get_not_translated(en_json, translation_json), file)
|
||||
|
||||
|
||||
def main():
|
||||
per_language_dict = defaultdict(list)
|
||||
for en_file in glob(path.join(DIR, EN_PATH)):
|
||||
for file in glob(path.join(DIR, ALL_PATH)):
|
||||
if path.basename(en_file) == path.basename(file) and file != en_file:
|
||||
file_info = audit(file, en_file)
|
||||
lang = path.basename(path.dirname(file))
|
||||
per_language_dict[lang].append(file_info)
|
||||
for results in per_language_dict.values():
|
||||
list(map(lambda args: print_result(*args), results))
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -66,7 +66,7 @@ def get_winner(pairs):
|
||||
|
||||
def get_ranking(pairs):
|
||||
"""
|
||||
Abuses concordance property to get a (not necessarily unqiue) ranking.
|
||||
Abuses concordance property to get a (not necessarily unique) ranking.
|
||||
The lack of uniqueness is due to the potential existence of multiple
|
||||
equally ranked winners. We have to pick one, which is where
|
||||
the non-uniqueness comes from
|
||||
|
||||
@@ -58,7 +58,7 @@ def score_update_votes(new_vote: int, consensus: npt.ArrayLike, voter_data: Vote
|
||||
after that voter cast a vote on a question.
|
||||
|
||||
This function is only to be run when archiving a question
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
|
||||
i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information
|
||||
|
||||
The consensus is the array of all votes cast by all voters for that question
|
||||
We then update the voter data using the new information
|
||||
@@ -88,7 +88,7 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
|
||||
This function returns the gain of points for a given prompt's votes
|
||||
|
||||
In contrast to the other score updating functions, we can run this online as new votes come in.
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information.
|
||||
i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information.
|
||||
|
||||
|
||||
Parameters:
|
||||
@@ -122,7 +122,7 @@ def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.Arr
|
||||
This function returns the gain of points for a given ranking's votes
|
||||
|
||||
This function is only to be run when archiving a question
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
|
||||
i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information
|
||||
|
||||
we use the bubble-sort distance (or "kendall-tau" distance) to compare the two rankings
|
||||
we use this over spearman correlation since:
|
||||
|
||||
@@ -56,7 +56,7 @@ def next_answer_task(possible_prompts, answers_per_prompt):
|
||||
This helps to not have too much close-to-finished prompts in the active set.
|
||||
|
||||
Parameters:
|
||||
possible_prompts (dict[prompt_id, num_answers]): a dictonary containing all open prompts and the number of answers these prompts currently have.
|
||||
possible_prompts (dict[prompt_id, num_answers]): a dictionary containing all open prompts and the number of answers these prompts currently have.
|
||||
answers_per_prompt (int): number of answers we per prompt to target
|
||||
Returns:
|
||||
prompt_id (int): the prompt_id corresponding to the next prompt that should get a new answer
|
||||
|
||||
@@ -28,6 +28,16 @@ def _render_message(message: dict) -> str:
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
# make sure dummy user has accepted the terms of service
|
||||
create_user_request = dict(USER)
|
||||
create_user_request["tos_acceptance"] = True
|
||||
response = requests.post(
|
||||
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user = response.json()
|
||||
typer.echo(f"user: {user}")
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
+204
-189
@@ -6,12 +6,10 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
import typer
|
||||
from faker import Faker
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# debug constants
|
||||
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
|
||||
fake = Faker()
|
||||
|
||||
|
||||
def _random_message_id():
|
||||
@@ -26,7 +24,9 @@ def _render_message(message: dict) -> str:
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
def main(
|
||||
backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234", random_users: int = 1, task_per_user: int = 10
|
||||
):
|
||||
"""automates tasks"""
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
@@ -50,204 +50,219 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
print(shuffled)
|
||||
return ranks
|
||||
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
q = 0
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
print(task)
|
||||
for i in range(int(random_users)):
|
||||
name = fake.name()
|
||||
USER = {"id": name, "display_name": name, "auth_method": "local"}
|
||||
|
||||
match (task["type"]):
|
||||
case "initial_prompt":
|
||||
typer.echo("Please provide an initial prompt to the assistant.")
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
create_user_request = dict(USER)
|
||||
# make sure dummy user has accepted the terms of service
|
||||
create_user_request["tos_acceptance"] = True
|
||||
response = requests.post(
|
||||
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user = response.json()
|
||||
typer.echo(f"user: {user}")
|
||||
q = 0
|
||||
|
||||
prompt = gen_random_text()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": prompt,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
|
||||
case "label_initial_prompt":
|
||||
typer.echo("Label the following prompt:")
|
||||
typer.echo(task["prompt"])
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
print(task)
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
mandatory_labels = task["mandatory_labels"]
|
||||
match (task["type"]):
|
||||
case "initial_prompt":
|
||||
typer.echo("Please provide an initial prompt to the assistant.")
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
for l in mandatory_labels:
|
||||
if l not in labels:
|
||||
labels.append(l)
|
||||
labels_dict = {label: random.random() for label in valid_labels}
|
||||
if random.random() < 0.9:
|
||||
labels_dict["spam"] = 0
|
||||
labels_dict["lang_mismatch"] = 0
|
||||
prompt = gen_random_text()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": prompt,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "prompter_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "label_initial_prompt":
|
||||
typer.echo("Label the following prompt:")
|
||||
typer.echo(task["prompt"])
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
case "assistant_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
valid_labels = task["valid_labels"]
|
||||
mandatory_labels = task["mandatory_labels"]
|
||||
|
||||
case "rank_prompter_replies" | "rank_assistant_replies":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["replies"])
|
||||
print(ranking)
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
for l in mandatory_labels:
|
||||
if l not in labels:
|
||||
labels.append(l)
|
||||
labels_dict = {label: random.random() for label in valid_labels}
|
||||
if random.random() < 0.9:
|
||||
labels_dict["spam"] = 0
|
||||
labels_dict["lang_mismatch"] = 0
|
||||
|
||||
case "rank_initial_prompts":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["prompots"])
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "prompter_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_prompter_reply" | "label_assistant_reply":
|
||||
# acknowledge task
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
case "assistant_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
typer.echo("Label the following reply:")
|
||||
typer.echo(task["reply"])
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
valid_labels = task["valid_labels"]
|
||||
mandatory_labels = task["mandatory_labels"]
|
||||
case "rank_prompter_replies" | "rank_assistant_replies":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["replies"])
|
||||
print(ranking)
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
for l in mandatory_labels:
|
||||
if l not in labels:
|
||||
labels.append(l)
|
||||
labels_dict = {label: random.random() for label in valid_labels}
|
||||
if random.random() < 0.9:
|
||||
labels_dict["spam"] = 0
|
||||
labels_dict["lang_mismatch"] = 0
|
||||
case "rank_initial_prompts":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["prompots"])
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
# rerun with new task slected from above cases
|
||||
# add a new task
|
||||
q += 1
|
||||
if q == 10:
|
||||
case "label_prompter_reply" | "label_assistant_reply":
|
||||
# acknowledge task
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
|
||||
typer.echo("Label the following reply:")
|
||||
typer.echo(task["reply"])
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
valid_labels = task["valid_labels"]
|
||||
mandatory_labels = task["mandatory_labels"]
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
for l in mandatory_labels:
|
||||
if l not in labels:
|
||||
labels.append(l)
|
||||
labels_dict = {label: random.random() for label in valid_labels}
|
||||
if random.random() < 0.9:
|
||||
labels_dict["spam"] = 0
|
||||
labels_dict["lang_mismatch"] = 0
|
||||
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
break
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
#
|
||||
case _:
|
||||
typer.echo(f"Unknown task type {task['type']}")
|
||||
# rerun with new task slected from above cases
|
||||
# rerun with new task selected from above cases
|
||||
# add a new task
|
||||
q += 1
|
||||
if q == task_per_user:
|
||||
typer.echo("Task done!")
|
||||
break
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
#
|
||||
case _:
|
||||
typer.echo(f"Unknown task type {task['type']}")
|
||||
# rerun with new task selected from above cases
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
faker==16.6.1
|
||||
requests==2.28.1
|
||||
typer==0.7.0
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user