Merge branch 'main' into sft-data-sampling

This commit is contained in:
sanagnos
2023-02-09 09:19:17 +01:00
committed by GitHub
332 changed files with 10207 additions and 2327 deletions
+17
View File
@@ -38,14 +38,31 @@ jobs:
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }}
MAX_INITIAL_PROMPT_REVIEW: ${{ vars.MAX_INITIAL_PROMPT_REVIEW }}
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }}
LONELY_CHILDREN_COUNT: ${{ vars.LONELY_CHILDREN_COUNT }}
P_LONELY_CHILD_EXTENSION: ${{ vars.P_LONELY_CHILD_EXTENSION }}
P_ACTIVATE_BACKLOG_TREE: ${{ vars.P_ACTIVATE_BACKLOG_TREE }}
MIN_ACTIVE_RANKINGS_PER_LANG: ${{ vars.MIN_ACTIVE_RANKINGS_PER_LANG }}
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
MESSAGE_SIZE_LIMIT: ${{ vars.MESSAGE_SIZE_LIMIT }}
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
STATS_INTERVAL_DAY: ${{ vars.STATS_INTERVAL_DAY }}
STATS_INTERVAL_WEEK: ${{ vars.STATS_INTERVAL_WEEK }}
STATS_INTERVAL_MONTH: ${{ vars.STATS_INTERVAL_MONTH }}
STATS_INTERVAL_TOTAL: ${{ vars.STATS_INTERVAL_TOTAL }}
WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
${{ secrets.WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }}
WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY:
${{ secrets.WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY }}
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }}
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }}
LOGURU_LEVEL: ${{ vars.LOGURU_LEVEL }}
MAINTENANCE_MODE: ${{ vars.MAINTENANCE_MODE }}
BACKEND_URL: ${{ vars.BACKEND_URL }}
steps:
- name: Checkout
uses: actions/checkout@v2
+2
View File
@@ -21,6 +21,8 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: "**/requirements*.txt"
- uses: pre-commit/action@v3.0.0
- name: Post PR comment on failure
if: failure() && github.event_name == 'pull_request_target'
+16
View File
@@ -0,0 +1,16 @@
# Deploys the "production2" stack whenever the production2 branch is pushed.
name: Deploy to prod2

on:
  push:
    branches:
      - production2

jobs:
  deploy-to-prod:
    # Delegates the actual deployment to the shared reusable workflow.
    uses: ./.github/workflows/deploy-to-node.yaml
    # Forward this repository's secrets to the called workflow.
    secrets: inherit
    with:
      stack-name: production2
      # Image tag comes from the PROD_IMAGE_TAG configuration variable.
      image-tag: ${{ vars.PROD_IMAGE_TAG }}
      backend-port: 8280
      website-port: 3200
+2
View File
@@ -23,6 +23,8 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: "**/requirements*.txt"
- uses: actions/setup-node@v3
with:
node-version: 16
+2 -1
View File
@@ -1,4 +1,5 @@
{
"python.formatting.provider": "black",
"python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"]
"python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"],
"prettier.singleQuote": false
}
+7 -7
View File
@@ -13,12 +13,12 @@ dedicated, but not as active, channel.
### Taking on Tasks
We have a growing task list
[of issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue
that appeals to you and make a comment that you'd like to work on it. Include in
your comment a brief description of how you'll solve the problem and if there
are any open questions you want to discuss. Once a project coordinator has
assigned the issue to you, start working on it.
We have a growing task list of
[issues](https://github.com/LAION-AI/Open-Assistant/issues). Find an issue that
appeals to you and make a comment that you'd like to work on it. Include in your
comment a brief description of how you'll solve the problem and if there are any
open questions you want to discuss. Once a project coordinator has assigned the
issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in Discord
and someone can help clarify the issue with more detail.
@@ -140,4 +140,4 @@ automatically deploy the built release to the dev machine.
### Contribute a Dataset
See
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/openassistant/datasets/README.md)
+26 -3
View File
@@ -14,6 +14,12 @@
</div>
# Here is our website to collect data:
[open-assistant.io](https://open-assistant.io)
(project documentation lives [here](https://laion-ai.github.io/Open-Assistant/))
# Table of Contents
- [What is Open Assistant?](#what-is-open-assistant)
@@ -38,9 +44,22 @@ improving language itself.
## Do you want to try it out?
If you are interested in taking a look at the current state of the project, you
### Contributing to Data Collection
The data collection frontend is now live [here](https://open-assistant.io/). Log
in and start taking on tasks! We want to collect a high volume of quality data.
By submitting, ranking, and labelling model prompts and responses you will be
directly helping to improve the capabilities of Open Assistant.
### Running Locally
**You do not need to run the project locally unless you are contributing to the
development process. The website link above will take you to the public website
where you can use the data collection app.**
If you would like to run the data collection app locally for development, you
can set up an entire stack needed to run **Open-Assistant**, including the
website, backend, and associated dependent services.
website, backend, and associated dependent services, with Docker.
##### To start the demo, run this in the root directory of the repository:
@@ -51,6 +70,10 @@ docker compose up --build
Then, navigate to `http://localhost:3000` (It may take some time to boot up) and
interact with the website.
> **Note:** If an issue occurs with the build, please head to the
> [FAQ](https://projects.laion.ai/Open-Assistant/docs/faq) and check out the
> entries about Docker.
> **Note:** When logging in via email, navigate to `http://localhost:1080` to
> get the magic email login link.
@@ -65,7 +88,7 @@ interact with the website.
## The Plan
##### We want to get to an initial MVP as fast as possible, by following the 3-steps outlined in the InstructGPT paper.
##### We want to get to an initial MVP as fast as possible, by following the 3 steps outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155).
1. Collect high-quality human generated Instruction-Fulfillment samples
(prompt + response), goal >50k. We design a crowdsourced process to collect
+1 -1
View File
@@ -1,6 +1,6 @@
To test the ansible playbook on localhost run
`ansible-playbook -i test.inventory.ini dev.yaml`.\
In case you're missing the ansible docker depencency install it with `ansible-galaxy collection install community.docker`.\
In case you're missing the ansible docker dependency install it with `ansible-galaxy collection install community.docker`.\
Point Redis Insights to the Redis database by visiting localhost:8001 in a
browser and select "I already have a database" followed by "Connect to a Redis
Database".\
+36 -1
View File
@@ -113,6 +113,9 @@
TREE_MANAGER__MAX_ACTIVE_TREES:
"{{ lookup('ansible.builtin.env', 'MAX_ACTIVE_TREES') |
default('10', true) }}"
TREE_MANAGER__MAX_INITIAL_PROMPT_REVIEW:
"{{ lookup('ansible.builtin.env', 'MAX_INITIAL_PROMPT_REVIEW') |
default('100', true) }}"
TREE_MANAGER__MAX_TREE_DEPTH:
"{{ lookup('ansible.builtin.env', 'MAX_TREE_DEPTH') | default('5',
true) }}"
@@ -122,6 +125,21 @@
TREE_MANAGER__MAX_CHILDREN_COUNT:
"{{ lookup('ansible.builtin.env', 'MAX_CHILDREN_COUNT') |
default('3', true) }}"
TREE_MANAGER__LONELY_CHILDREN_COUNT:
"{{ lookup('ansible.builtin.env', 'LONELY_CHILDREN_COUNT') |
default('2', true) }}"
TREE_MANAGER__P_LONELY_CHILD_EXTENSION:
"{{ lookup('ansible.builtin.env', 'P_LONELY_CHILD_EXTENSION') |
default('0.75', true) }}"
TREE_MANAGER__P_ACTIVATE_BACKLOG_TREE:
"{{ lookup('ansible.builtin.env', 'P_ACTIVATE_BACKLOG_TREE') |
default('0.1', true) }}"
TREE_MANAGER__MIN_ACTIVE_RANKINGS_PER_LANG:
"{{ lookup('ansible.builtin.env', 'MIN_ACTIVE_RANKINGS_PER_LANG') |
default('0', true) }}"
MESSAGE_SIZE_LIMIT:
"{{ lookup('ansible.builtin.env', 'MESSAGE_SIZE_LIMIT') |
default('2000', true) }}"
USER_STATS_INTERVAL_DAY:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_DAY') |
default('5', true) }}"
@@ -134,6 +152,9 @@
USER_STATS_INTERVAL_TOTAL:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_TOTAL') |
default('240', true) }}"
LOGURU_LEVEL:
"{{ lookup('ansible.builtin.env', 'LOGURU_LEVEL') | default('INFO',
true) }}"
ports:
- "{{ backend_port }}:8080"
@@ -165,13 +186,27 @@
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PORT') }}"
EMAIL_SERVER_USER:
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_USER') }}"
FASTAPI_URL: "http://oasst-{{ stack_name }}-backend:8080"
FASTAPI_URL: "{{ lookup('ansible.builtin.env', 'BACKEND_URL') }}"
FASTAPI_KEY: "{{ web_api_key }}"
NEXTAUTH_SECRET:
"{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}"
NEXTAUTH_URL:
"{{ 'https://open-assistant.io/' if stack_name == 'production' else
('https://web.' + stack_name + '.open-assistant.io/') }}"
NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
"{{ lookup('ansible.builtin.env',
'WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY') }}"
CLOUDFLARE_CAPTCHA_SECRET_KEY:
"{{ lookup('ansible.builtin.env',
'WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY') }}"
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
"{{ lookup('ansible.builtin.env',
'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA') }}"
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
"{{ lookup('ansible.builtin.env',
'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN') }}"
MAINTENANCE_MODE:
"{{ lookup('ansible.builtin.env', 'MAINTENANCE_MODE') }}"
ports:
- "{{ website_port }}:3000"
command: bash wait-for-postgres.sh node server.js
@@ -36,6 +36,7 @@ def upgrade() -> None:
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user_stats", "streak_days")
op.drop_column("user_stats", "streak_last_day_date")
op.drop_column("user", "streak_days")
op.drop_column("user", "streak_last_day_date")
op.drop_column("user", "last_activity_date")
# ### end Alembic commands ###
@@ -0,0 +1,34 @@
"""add tos_acceptance_date to user
Revision ID: 55361f323d12
Revises: f60958968ff8
Create Date: 2023-02-01 00:22:08.280251
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55361f323d12"
down_revision = "f60958968ff8"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add ``user.tos_acceptance_date`` and drop the streak columns from ``user_stats``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
    op.drop_column("user_stats", "streak_days")
    op.drop_column("user_stats", "streak_last_day_date")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Re-create the ``user_stats`` streak columns and drop ``user.tos_acceptance_date``.

    The streak columns are re-added as nullable columns; values removed by
    ``upgrade`` are not restored.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
    )
    op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
    op.drop_column("user", "tos_acceptance_date")
    # ### end Alembic commands ###
@@ -0,0 +1,27 @@
"""add won_prompt_lottery_date to mts
Revision ID: f60958968ff8
Revises: 7b8f0011e0b0
Create Date: 2023-02-01 10:10:38.301707
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "f60958968ff8"
down_revision = "7b8f0011e0b0"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable, timezone-aware ``won_prompt_lottery_date`` column to ``message_tree_state``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("message_tree_state", sa.Column("won_prompt_lottery_date", sa.DateTime(timezone=True), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop ``message_tree_state.won_prompt_lottery_date`` again."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("message_tree_state", "won_prompt_lottery_date")
    # ### end Alembic commands ###
@@ -0,0 +1,30 @@
"""add skip bool & skip_reason to task
Revision ID: 9e7ec4a9e3f2
Revises: 55361f323d12
Create Date: 2023-02-01 21:46:49.971052
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "9e7ec4a9e3f2"
down_revision = "55361f323d12"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the ``skipped`` flag and the optional ``skip_reason`` text to ``task``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # The server default lets the NOT NULL column be added to a table that already has rows.
    op.add_column("task", sa.Column("skipped", sa.Boolean(), server_default=sa.text("false"), nullable=False))
    op.add_column("task", sa.Column("skip_reason", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop ``task.skip_reason`` and ``task.skipped`` (reverse order of ``upgrade``)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("task", "skip_reason")
    op.drop_column("task", "skipped")
    # ### end Alembic commands ###
@@ -0,0 +1,59 @@
"""add troll_stats
Revision ID: 4d7e0b0ebe84
Revises: 9e7ec4a9e3f2
Create Date: 2023-02-02 15:44:12.647260
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4d7e0b0ebe84"
down_revision = "9e7ec4a9e3f2"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``troll_stats`` table and its unique ``(time_frame, user_id)`` index.

    Rows are keyed by ``(user_id, time_frame)`` and are removed together with
    the referenced ``user`` row (``ON DELETE CASCADE``).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "troll_stats",
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
        ),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("troll_score", sa.Integer(), nullable=False),
        sa.Column("rank", sa.Integer(), nullable=True),
        sa.Column("red_flags", sa.Integer(), nullable=False),
        sa.Column("upvotes", sa.Integer(), nullable=False),
        sa.Column("downvotes", sa.Integer(), nullable=False),
        sa.Column("spam_prompts", sa.Integer(), nullable=False),
        sa.Column("quality", sa.Float(), nullable=True),
        sa.Column("humor", sa.Float(), nullable=True),
        sa.Column("toxicity", sa.Float(), nullable=True),
        sa.Column("violence", sa.Float(), nullable=True),
        sa.Column("helpfulness", sa.Float(), nullable=True),
        sa.Column("spam", sa.Integer(), nullable=False),
        # NOTE(review): "lang_mismach" looks like a misspelling of "lang_mismatch"; it is the
        # persisted column name, so renaming it needs a coordinated model + migration change.
        sa.Column("lang_mismach", sa.Integer(), nullable=False),
        sa.Column("not_appropriate", sa.Integer(), nullable=False),
        sa.Column("pii", sa.Integer(), nullable=False),
        sa.Column("hate_speech", sa.Integer(), nullable=False),
        sa.Column("sexual_content", sa.Integer(), nullable=False),
        sa.Column("political_content", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    )
    op.create_index("ix_troll_stats__timeframe__user_id", "troll_stats", ["time_frame", "user_id"], unique=True)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``troll_stats`` index and then the table itself."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("ix_troll_stats__timeframe__user_id", table_name="troll_stats")
    op.drop_table("troll_stats")
    # ### end Alembic commands ###
@@ -0,0 +1,39 @@
"""Add Account table
Revision ID: 8c8241d1f973
Revises: 4d7e0b0ebe84
Create Date: 2023-01-30 15:10:58.776315
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "8c8241d1f973"
down_revision = "4d7e0b0ebe84"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``account`` table linking a ``user`` row to an external provider account."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "account",
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    # NOTE(review): the unique index is *named* "provider" but only covers provider_account_id,
    # so the same account id cannot exist for two different providers - confirm this is intended.
    op.create_index("provider", "account", ["provider_account_id"], unique=True)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``provider`` index and the ``account`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("provider", table_name="account")
    op.drop_table("account")
    # ### end Alembic commands ###
@@ -0,0 +1,41 @@
"""Added new table for flagged messages
Revision ID: caee1e8ee0bc
Revises: 8c8241d1f973
Create Date: 2023-02-07 19:22:12.696257
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "caee1e8ee0bc"
down_revision = "8c8241d1f973"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``flagged_message`` table with indexes on ``created_date`` and ``processed``.

    The table is keyed by ``message_id``; rows are deleted together with the
    referenced ``message`` row (``ON DELETE CASCADE``).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "flagged_message",
        sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column(
            "created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
        ),
        sa.Column("processed", sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("message_id"),
    )
    op.create_index(op.f("ix_flagged_message_created_date"), "flagged_message", ["created_date"], unique=False)
    op.create_index(op.f("ix_flagged_message_processed"), "flagged_message", ["processed"], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop both ``flagged_message`` indexes and then the table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f("ix_flagged_message_processed"), table_name="flagged_message")
    op.drop_index(op.f("ix_flagged_message_created_date"), table_name="flagged_message")
    op.drop_table("flagged_message")
    # ### end Alembic commands ###
+218
View File
@@ -0,0 +1,218 @@
import argparse
from pathlib import Path
from typing import List, Optional
from uuid import UUID
from loguru import logger
from oasst_backend.database import engine
from oasst_backend.models import Message, MessageTreeState
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_backend.utils import tree_export
from sqlmodel import Session, not_
def fetch_tree_ids(
    db: Session,
    state_filter: Optional[TreeState] = None,
    lang: Optional[str] = None,
) -> list[tuple[UUID, TreeState]]:
    """Return ``(message_tree_id, state)`` pairs for the matching message trees.

    Each tree state is joined to the message whose id equals the tree id, so
    that trees can be filtered by that message's language.

    Args:
        db: Open database session.
        state_filter: When truthy, keep only trees in this state.
        lang: When not ``None``, keep only trees whose joined message has this language.
    """
    conditions = []
    if lang is not None:
        conditions.append(Message.lang == lang)
    if state_filter:
        conditions.append(MessageTreeState.state == state_filter)

    query = db.query(MessageTreeState).select_from(MessageTreeState)
    query = query.join(Message, MessageTreeState.message_tree_id == Message.id)
    for condition in conditions:
        query = query.filter(condition)

    pairs = []
    for tree_state in query:
        pairs.append((tree_state.message_tree_id, tree_state.state))
    return pairs
def fetch_tree_messages(
    db: Session,
    message_tree_id: Optional[UUID] = None,
    user_id: Optional[UUID] = None,
    deleted: Optional[bool] = None,
    prompts_only: bool = False,
    lang: Optional[str] = None,
    review_result: Optional[bool] = None,
) -> List[Message]:
    """Query messages, applying only the filters that were actually supplied.

    Args:
        db: Open database session.
        message_tree_id: Restrict to messages of this tree.
        user_id: Restrict to messages written by this user.
        deleted: ``True``/``False`` selects deleted/non-deleted messages;
            ``None`` applies no deletion filter.
        prompts_only: Keep only messages without a parent.
        lang: Restrict to messages with this language code.
        review_result: ``True`` keeps messages with a positive review result;
            ``False`` keeps messages with a negative result and more than two
            reviews; ``None`` applies no review filter.

    Returns:
        All matching messages.
    """
    qry = db.query(Message)

    if message_tree_id:
        qry = qry.filter(Message.message_tree_id == message_tree_id)
    if user_id:
        qry = qry.filter(Message.user_id == user_id)
    if deleted is not None:
        qry = qry.filter(Message.deleted == deleted)
    if prompts_only:
        qry = qry.filter(Message.parent_id.is_(None))
    if lang:
        qry = qry.filter(Message.lang == lang)

    if review_result is False:
        # A negative result only counts once the message has more than two reviews.
        qry = qry.filter(not_(Message.review_result), Message.review_count > 2)
    elif review_result is True:
        qry = qry.filter(Message.review_result)

    return qry.all()
def export_trees(
    db: Session,
    export_file: Optional[Path] = None,
    use_compression: bool = False,
    deleted: Optional[bool] = False,
    user_id: Optional[UUID] = None,
    prompts_only: bool = False,
    state_filter: Optional[TreeState] = None,
    lang: Optional[str] = None,
    review_result: Optional[bool] = None,
) -> None:
    """Export messages or message trees to ``export_file``.

    Three output shapes are produced:

    * a flat message list when filtering by ``user_id`` or when exporting
      negatively reviewed messages (``review_result is False``);
    * a flat message list when ``deleted is True``;
    * otherwise one export tree per selected message tree.

    Args:
        db: Open database session.
        export_file: Destination path, passed through to the ``tree_export`` writers.
        use_compression: Passed through to the ``tree_export`` writers.
        deleted: ``True``/``False`` selects deleted/non-deleted messages;
            ``None`` applies no deletion filter.
        user_id: Export only this user's messages (flat list).
        prompts_only: Export only messages without a parent; for tree output
            the prompt's replies are cleared.
        state_filter: Tree state used to select trees (ignored on the flat
            ``user_id`` / negative-review path).
        lang: Language code filter.
        review_result: Review filter forwarded to ``fetch_tree_messages``.
    """
    if user_id or review_result is False:
        messages = fetch_tree_messages(
            db,
            user_id=user_id,
            deleted=deleted,
            prompts_only=prompts_only,
            lang=lang,
            review_result=review_result,
        )
        tree_export.write_messages_to_file(export_file, messages, use_compression)
        return

    message_tree_ids = fetch_tree_ids(db, state_filter, lang=lang)
    message_trees = [
        fetch_tree_messages(
            db,
            message_tree_id=tree_id,
            deleted=deleted,
            prompts_only=prompts_only,
            # The trees were already selected by language; no per-message language filter here.
            lang=None,
            review_result=review_result,
        )
        for (tree_id, _) in message_tree_ids
    ]

    # when exporting only-deleted we don't have a proper tree structure, export as list
    if deleted is True:
        messages = [m for t in message_trees for m in t]
        tree_export.write_messages_to_file(export_file, messages, use_compression)
        return

    trees_to_export: List[tree_export.ExportMessageTree] = []
    for (message_tree_id, message_tree_state), message_tree in zip(message_tree_ids, message_trees):
        t = tree_export.build_export_tree(message_tree_id, message_tree_state, message_tree)
        if prompts_only:
            t.prompt.replies = None
        trees_to_export.append(t)

    tree_export.write_trees_to_file(export_file, trees_to_export, use_compression)
def validate_args(args):
    """Normalise and validate the parsed command line arguments in place.

    * ``--deleted-only`` switches ``include_deleted`` on.
    * ``use_compression`` is set when the export file name ends in ``.gz``
      (a suffix check, so names such as ``dump.gz.json`` are not compressed).
    * A warning is logged when no export file was given.

    Raises:
        ValueError: If ``--state`` is combined with ``--user``.
    """
    if args.deleted_only:
        args.include_deleted = True

    args.use_compression = args.export_file is not None and args.export_file.endswith(".gz")

    if args.state and args.user is not None:
        raise ValueError("Cannot use --state when specifying a user ID")

    if args.export_file is None:
        logger.warning("No export file provided, output will be sent to STDOUT")
def parse_args():
    """Build the command line parser for the export script and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--export-file",
        type=str,
        help="Name of file to export trees to. If not provided, output will be sent to STDOUT",
    )
    parser.add_argument(
        "--include-deleted",
        action="store_true",
        help="Include deleted messages in export",
    )
    parser.add_argument(
        "--deleted-only",
        action="store_true",
        help="Export only deleted messages (implies --include-deleted)",
    )
    parser.add_argument(
        "--include-spam",
        action="store_true",
        # This flag widens the export; "--spam-only" is the one that restricts it.
        help="Include messages with negative review result in export",
    )
    parser.add_argument(
        "--spam-only",
        action="store_true",
        help="Export only messages with negative review result (implies --include-spam).",
    )
    parser.add_argument(
        "--user",
        type=str,
        help="Only export trees involving the user with the specified ID. Incompatible with --state.",
    )
    parser.add_argument(
        "--state",
        type=str,
        help="all|prompt_lottery_waiting|growing|ready_for_export|aborted_low_grade|halted_by_moderator|backlog_ranking",
    )
    parser.add_argument(
        "--lang",
        type=str,
        help="Filter message trees by language code (BCP 47)",
    )
    parser.add_argument(
        "--prompts-only",
        action="store_true",
        help="Export a list of initial prompt messages",
    )

    return parser.parse_args()
def main():
    """Entry point: translate the command line flags into filters and run the export.

    Defaults: only trees in state ``READY_FOR_EXPORT``, only non-deleted
    messages, only messages with a positive review result.
    """
    args = parse_args()
    validate_args(args)

    # --state: omitted -> ready-for-export trees; "all" -> no state filter.
    if args.state is None:
        state_filter: Optional[TreeState] = TreeState.READY_FOR_EXPORT
    elif args.state == "all":
        state_filter = None
    else:
        state_filter = TreeState(args.state)

    # --deleted-only takes precedence over --include-deleted.
    if args.deleted_only:
        deleted: Optional[bool] = True
    elif args.include_deleted:
        deleted = None
    else:
        deleted = False

    # --spam-only takes precedence over --include-spam.
    if args.spam_only:
        review_result: Optional[bool] = False
    elif args.include_spam:
        review_result = None
    else:
        review_result = True

    with Session(engine) as db:
        export_path = Path(args.export_file) if args.export_file is not None else None
        user_id = UUID(args.user) if args.user is not None else None
        export_trees(
            db,
            export_path,
            args.use_compression,
            deleted=deleted,
            user_id=user_id,
            prompts_only=args.prompts_only,
            state_filter=state_filter,
            lang=args.lang,
            review_result=review_result,
        )
if __name__ == "__main__":
main()
+11 -35
View File
@@ -18,7 +18,8 @@ from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_backend.prompt_repository import PromptRepository, UserRepository
from oasst_backend.task_repository import TaskRepository, delete_expired_tasks
from oasst_backend.tree_manager import TreeManager
from oasst_backend.user_repository import User
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
@@ -147,6 +148,7 @@ if settings.DEBUG_USE_SEED_DATA:
ur = UserRepository(db=session, api_client=api_client)
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
pr = PromptRepository(
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
@@ -317,6 +319,13 @@ def update_user_streak(session: Session) -> None:
return
@app.on_event("startup")
@repeat_every(seconds=60 * 60)  # 1 hour
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def cronjob_delete_expired_tasks(session: Session) -> None:
    """Periodic job that removes expired tasks.

    Registered as an application startup handler and wrapped in
    ``repeat_every`` with a one-hour interval; the work itself is delegated to
    ``delete_expired_tasks`` using the supplied session.
    """
    delete_expired_tasks(session)
app.include_router(api_router, prefix=settings.API_V1_STR)
@@ -324,24 +333,6 @@ def get_openapi_schema():
return json.dumps(app.openapi())
def export_ready_trees(file: Optional[str] = None, use_compression: bool = False):
try:
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
ur = UserRepository(db=db, api_client=api_client)
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
pr = PromptRepository(
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
tm = TreeManager(db, pr)
tm.export_all_ready_trees(file, use_compression=use_compression)
except Exception:
logger.exception("Error exporting trees.")
def retry_scoring_failed_message_trees():
try:
logger.info("TreeManager.retry_scoring_failed_message_trees()")
@@ -373,17 +364,6 @@ def main():
)
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
parser.add_argument("--port", help="The port to run the server", default=8080)
parser.add_argument(
"--export",
default=False,
help="Export all trees which are ready for exporting.",
action="store_true",
)
parser.add_argument(
"--export-file",
type=str,
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
)
parser.add_argument(
"--retry-scoring",
default=False,
@@ -396,14 +376,10 @@ def main():
if args.print_openapi_schema:
print(get_openapi_schema())
if args.export:
use_compression: bool = ".gz" in args.export_file
export_ready_trees(file=args.export_file, use_compression=use_compression)
if args.retry_scoring:
retry_scoring_failed_message_trees()
if not (args.export or args.print_openapi_schema or args.retry_scoring):
if not (args.print_openapi_schema or args.retry_scoring):
uvicorn.run(app, host=args.host, port=args.port)
+35
View File
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import pydantic
@@ -162,3 +163,37 @@ async def purge_user_messages(
logger.info(f"{before=}; {after=}")
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
class FlaggedMessageResponse(pydantic.BaseModel):
    """Response schema describing one flagged-message record."""

    # Id of the flagged message.
    message_id: UUID
    # Processed flag of the record.
    processed: bool
    # Creation timestamp of the record (optional).
    created_date: Optional[datetime]
@router.get("/flagged_messages", response_model=list[FlaggedMessageResponse])
async def get_flagged_messages(
    max_count: Optional[int] = None,
    session: deps.Session = Depends(deps.get_db),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> list[FlaggedMessageResponse]:
    """List flagged messages for a trusted API client.

    Args:
        max_count: Optional upper bound forwarded to the repository; the query
            parameter may be omitted.
        session: Database session dependency.
        api_client: API client resolved by the trusted-client dependency.

    Returns:
        One ``FlaggedMessageResponse`` per flagged-message record.
    """
    assert api_client.trusted

    pr = PromptRepository(session, api_client)
    flagged_messages = pr.fetch_flagged_messages(max_count=max_count)
    # Build the response models from the ORM objects' attribute dicts.
    return [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages]
# NOTE(review): this path starts with "/admin" while the sibling GET route is plain
# "/flagged_messages"; if the router is already mounted under "/admin" the effective URL
# repeats the segment. Left unchanged because existing clients may rely on it - confirm.
@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse)
async def process_flagged_messages(
    message_id: UUID,
    session: deps.Session = Depends(deps.get_db),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> FlaggedMessageResponse:
    """Process the flagged-message record of ``message_id`` and return it.

    Args:
        message_id: Id of the flagged message.
        session: Database session dependency.
        api_client: API client resolved by the trusted-client dependency.

    Returns:
        The record returned by ``PromptRepository.process_flagged_message``,
        converted to ``FlaggedMessageResponse``.
    """
    assert api_client.trusted

    pr = PromptRepository(session, api_client)
    flagged_msg = pr.process_flagged_message(message_id=message_id)
    return FlaggedMessageResponse(**flagged_msg.__dict__)
+2
View File
@@ -10,6 +10,7 @@ from oasst_backend.api.v1 import (
stats,
tasks,
text_labels,
trollboards,
users,
)
@@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
@@ -59,6 +59,37 @@ def query_frontend_user(
return user.to_protocol_frontend_user()
@router.post("/", response_model=protocol.FrontEndUser)
def create_frontend_user(
    *,
    create_user: protocol.CreateFrontendUserRequest,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Look up the frontend user (creating it if missing) and apply requested settings.

    ``update_user`` is only called when a supplied (non-``None``) value differs
    from the stored one, or when terms-of-service acceptance is requested and
    the user has no acceptance date yet.
    """
    user_repository = UserRepository(db, api_client)
    user = user_repository.lookup_client_user(create_user, create_missing=True)

    # (requested value, stored value) pairs; a requested value of None means "leave as is".
    requested_vs_stored = (
        (create_user.enabled, user.enabled),
        (create_user.show_on_leaderboard, user.show_on_leaderboard),
        (create_user.notes, user.notes),
    )
    settings_differ = any(
        requested is not None and requested != stored for requested, stored in requested_vs_stored
    )
    tos_newly_accepted = create_user.tos_acceptance and user.tos_acceptance_date is None

    if settings_differ or tos_newly_accepted:
        user = user_repository.update_user(
            user.id,
            enabled=create_user.enabled,
            show_on_leaderboard=create_user.show_on_leaderboard,
            tos_acceptance=create_user.tos_acceptance,
            notes=create_user.notes,
        )

    return user.to_protocol_frontend_user()
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
auth_method: str,
+1 -1
View File
@@ -18,7 +18,7 @@ async def get_text_toxicity(
Args:
msg (str): the message that we want to analyze.
api_client (ApiClient, optional): authentification of the user of the request.
api_client (ApiClient, optional): authentication of the user of the request.
Defaults to Depends(deps.get_trusted_api_client).
Returns:
+73
View File
@@ -0,0 +1,73 @@
import aiohttp
from fastapi import APIRouter, Depends, HTTPException, Request
from oasst_backend import auth
from oasst_backend.api import deps
from oasst_backend.config import Settings
from oasst_backend.models import Account
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_401_UNAUTHORIZED
router = APIRouter()
@router.get("/discord")
def login_discord(request: Request):
    """Redirect the caller to Discord's OAuth2 authorization page.

    The redirect is issued by raising an ``HTTPException`` with status 302 and
    a ``location`` header. Query parameters are URL-encoded so that the
    ``redirect_uri`` value (which itself contains ``:`` and ``/``) is passed
    as a single, well-formed parameter.
    """
    from urllib.parse import urlencode

    # NOTE(review): the client id is read from the ``Settings`` class rather than a settings
    # instance - confirm that the attribute is available at class level.
    query = urlencode(
        {
            "client_id": Settings.AUTH_DISCORD_CLIENT_ID,
            "redirect_uri": f"{get_callback_uri(request)}/discord",
            "response_type": "code",
            "scope": "identify",
        }
    )
    auth_url = f"https://discord.com/api/oauth2/authorize?{query}"
    raise HTTPException(status_code=302, headers={"location": auth_url})
@router.get("/callback/discord", response_model=protocol_schema.Token)
async def callback_discord(
    auth_code: str,
    request: Request,
    db: Session = Depends(deps.get_db),
):
    """Complete the Discord OAuth2 flow and issue an access token.

    Steps: exchange ``auth_code`` for a Discord access token, fetch the Discord
    user id with it, look up the linked account and create a JWT for it.

    Raises:
        OasstError: ``INVALID_AUTHENTICATION`` (HTTP 401) when no account is
            linked to the Discord id.
    """
    # NOTE(review): OAuth2 providers normally send the code as ``code``; confirm that the
    # caller really supplies it under the name ``auth_code``.
    redirect_uri = f"{get_callback_uri(request)}/discord"

    async with aiohttp.ClientSession(raise_for_status=True) as http:
        # Exchange the auth code for a Discord access token
        token_request = {
            "client_id": Settings.AUTH_DISCORD_CLIENT_ID,
            "client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET,
            "grant_type": "authorization_code",
            "code": auth_code,
            "redirect_uri": redirect_uri,
            "scope": "identify",
        }
        async with http.post("https://discord.com/api/oauth2/token", data=token_request) as token_response:
            discord_token = (await token_response.json())["access_token"]

        # Retrieve user's Discord information using access token
        bearer_headers = {"Authorization": f"Bearer {discord_token}"}
        async with http.get("https://discord.com/api/users/@me", headers=bearer_headers) as user_response:
            discord_id = (await user_response.json())["id"]

    account: Account = auth.get_account_from_discord_id(db, discord_id)
    if account:
        # Discord account is valid and linked to an OA account -> create JWT
        jwt_token = auth.create_access_token(account)
        return protocol_schema.Token(access_token=jwt_token, token_type="bearer")

    # Discord account is not linked to an OA account
    raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
def get_callback_uri(request: Request):
    """
    Gets the URI for the base callback endpoint with no provider name appended.

    The result is whatever precedes the first "/api/v1/" in the current request
    URL, followed by "/api/v1/callback". If the request URL does not contain
    "/api/v1/", the whole URL is used as the prefix.
    """
    # Derived from the incoming URL by string manipulation; not pretty, but it
    # avoids configuring the public host name separately.
    prefix, _, _ = str(request.url).partition("/api/v1/")
    return f"{prefix}/api/v1/callback"
+3 -2
View File
@@ -104,7 +104,8 @@ def tasks_acknowledge(
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
logger.info(f"Frontend ACK task_id={task_id}")
logger.debug(f"{ack_request=}.")
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
@@ -131,7 +132,7 @@ def tasks_acknowledge_failure(
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
pr.task_repository.acknowledge_task_failure(task_id)
pr.skip_task(task_id=task_id, reason=nack_request.reason)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@@ -0,0 +1,21 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import TrollboardStats
from sqlmodel import Session
router = APIRouter()
@router.get("/{time_frame}", response_model=TrollboardStats)
def get_trollboard(
    time_frame: UserStatsTimeFrame,
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
) -> TrollboardStats:
    """
    Return the trollboard for the given time frame.

    `max_count` limits the number of returned entries (1..10000, default 100).
    `api_client` is not used in the body; it is declared so that the
    `get_trusted_api_client` dependency is resolved for this route.
    """
    return UserStatsRepository(db).get_trollboard(time_frame, limit=max_count)
+2 -1
View File
@@ -191,6 +191,7 @@ def update_user(
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
tos_acceptance: Optional[bool] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
@@ -198,7 +199,7 @@ def update_user(
Update a user by global user ID. Only trusted clients can update users.
"""
ur = UserRepository(db, api_client)
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
+8
View File
@@ -17,8 +17,15 @@ def prepare_message(m: Message) -> protocol.Message:
created_date=m.created_date,
emojis=m.emojis or {},
user_emojis=m.user_emojis or [],
user_is_author=m.user_is_author,
review_result=m.review_result,
review_count=m.review_count,
ranking_count=m.ranking_count,
deleted=m.deleted,
synthetic=m.synthetic,
model_name=m.model_name,
message_tree_id=m.message_tree_id,
rank=m.rank,
)
@@ -36,6 +43,7 @@ def prepare_conversation_message(message: Message) -> protocol.ConversationMessa
is_assistant=(message.role == "assistant"),
emojis=message.emojis or {},
user_emojis=message.user_emojis or [],
user_is_author=message.user_is_author,
)
+37
View File
@@ -0,0 +1,37 @@
from datetime import datetime, timedelta
from typing import Optional
from jose import jwt
from oasst_backend.config import Settings
from oasst_backend.models import Account
from sqlmodel import Session
def create_access_token(data: dict) -> str:
    """
    Create an encoded JSON Web Token (JWT) using the given data.

    The input mapping is copied, so the caller's object is left untouched, and
    an "exp" claim set to now (UTC) plus AUTH_ACCESS_TOKEN_EXPIRE_MINUTES is
    added before signing with AUTH_SECRET / AUTH_ALGORITHM.
    """
    # NOTE(review): configuration is read from the `Settings` class here, while
    # other modules use the `settings` instance - confirm this is intended.
    lifetime = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = data.copy()
    claims["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(claims, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM)
def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
    """
    Get the Open-Assistant Account associated with the given Discord ID.

    Returns the first account whose provider is "discord" and whose
    provider_account_id equals `discord_id`, or None when there is no match.
    """
    discord_accounts = db.query(Account).filter(
        Account.provider == "discord",
        Account.provider_account_id == discord_id,
    )
    return discord_accounts.first()
+43 -9
View File
@@ -13,7 +13,10 @@ class TreeManagerConfiguration(BaseModel):
No new initial prompt tasks are handed out to users if this
number is reached."""
max_tree_depth: int = 6
max_initial_prompt_review: int = 100
"""Maximum number of initial prompts under review before no more initial prompt tasks will be handed out."""
max_tree_depth: int = 3
"""Maximum depth of message tree."""
max_children_count: int = 3
@@ -22,22 +25,39 @@ class TreeManagerConfiguration(BaseModel):
num_prompter_replies: int = 1
"""Number of prompter replies to collect per assistant reply."""
goal_tree_size: int = 15
goal_tree_size: int = 12
"""Total number of messages to gather per tree."""
random_goal_tree_size: bool = False
"""If set to true goal tree sizes will be generated randomly within range [min_goal_tree_size, goal_tree_size]."""
min_goal_tree_size: int = 5
"""Minimum tree size for random goal sizes."""
num_reviews_initial_prompt: int = 3
"""Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state."""
num_reviews_reply: int = 3
"""Number of peer review checks to collect per reply (other than initial_prompt)."""
p_full_labeling_review_prompt: float = 0.1
auto_mod_enabled: bool = True
"""Flag to enable/disable auto moderation."""
auto_mod_max_skip_reply: int = 25
"""Automatically set tree state to `halted_by_moderator` when more than the specified number
of users skip replying to a message. (auto moderation)"""
auto_mod_red_flags: int = 4
"""Delete messages that receive more than this number of red flags if it is a reply or
set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""
p_full_labeling_review_prompt: float = 1.0
"""Probability of full text-labeling (instead of mandatory only) for initial prompts."""
p_full_labeling_review_reply_assistant: float = 0.1
p_full_labeling_review_reply_assistant: float = 0.5
"""Probability of full text-labeling (instead of mandatory only) for assistant replies."""
p_full_labeling_review_reply_prompter: float = 0.1
p_full_labeling_review_reply_prompter: float = 0.25
"""Probability of full text-labeling (instead of mandatory only) for prompter replies."""
acceptance_threshold_initial_prompt: float = 0.6
@@ -55,7 +75,7 @@ class TreeManagerConfiguration(BaseModel):
min_active_rankings_per_lang: int = 0
"""When the number of active ranking tasks is below this value when a tree enters a terminal
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
state an available trees in BACKLOG_RANKING will be activated (i.e. enters the RANKING state)."""
labels_initial_prompt: list[TextLabel] = [
TextLabel.spam,
@@ -112,13 +132,13 @@ class TreeManagerConfiguration(BaseModel):
rank_prompter_replies: bool = False
lonely_children_count: int = 3
lonely_children_count: int = 2
"""Number of children below which parents are preferred during sampling for reply tasks."""
p_lonely_child_extension: float = 0.8
p_lonely_child_extension: float = 0.75
"""Probability to select a prompter message parent with less than lonely_children_count children."""
recent_tasks_span_sec: int = 3 * 60 # 3 min
recent_tasks_span_sec: int = 5 * 60 # 5 min
"""Time in seconds of recent tasks to consider for exclusion during task selection."""
@@ -135,6 +155,11 @@ class Settings(BaseSettings):
AUTH_LENGTH: int = 32
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
AUTH_COOKIE_NAME: str = "next-auth.session-token"
AUTH_ALGORITHM: str = "HS256"
AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
AUTH_DISCORD_CLIENT_ID: str = ""
AUTH_DISCORD_CLIENT_SECRET: str = ""
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
@@ -144,6 +169,9 @@ class Settings(BaseSettings):
DATABASE_URI: Optional[PostgresDsn] = None
DATABASE_MAX_TX_RETRY_COUNT: int = 3
DATABASE_POOL_SIZE = 75
DATABASE_MAX_OVERFLOW = 20
RATE_LIMIT: bool = True
MESSAGE_SIZE_LIMIT: int = 2000
REDIS_HOST: str = "localhost"
@@ -154,10 +182,14 @@ class Settings(BaseSettings):
Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
)
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
DEBUG_ALLOW_SELF_RANKING: bool = False # allow users to rank their own messages
DEBUG_ALLOW_DUPLICATE_TASKS: bool = False # offer users tasks to which they already responded
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
DEBUG_DATABASE_ECHO: bool = False
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
True # TODO: set False after ToS acceptance UI was added to web-frontend
)
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
@@ -214,6 +246,8 @@ class Settings(BaseSettings):
RATE_LIMIT_TASK_API_TIMES: int = 10_000
RATE_LIMIT_TASK_API_MINUTES: int = 1
TASK_VALIDITY_MINUTES: int = 60 * 24 * 2 # tasks expire after 2 days
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
+7 -1
View File
@@ -5,4 +5,10 @@ from sqlmodel import create_engine
if settings.DATABASE_URI is None:
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
engine = create_engine(settings.DATABASE_URI, echo=settings.DEBUG_DATABASE_ECHO, isolation_level="REPEATABLE READ")
engine = create_engine(
settings.DATABASE_URI,
echo=settings.DEBUG_DATABASE_ECHO,
isolation_level="REPEATABLE READ",
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
)
+4
View File
@@ -1,4 +1,5 @@
from .api_client import ApiClient
from .flagged_message import FlaggedMessage
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
@@ -8,6 +9,7 @@ from .message_toxicity import MessageToxicity
from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
from .troll_stats import TrollStats
from .user import User
from .user_stats import UserStats, UserStatsTimeFrame
@@ -26,4 +28,6 @@ __all__ = [
"Journal",
"JournalIntegration",
"MessageEmoji",
"TrollStats",
"FlaggedMessage",
]
+5 -3
View File
@@ -3,7 +3,7 @@ from uuid import UUID
from oasst_backend.models.payload_column_type import payload_type
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
from pydantic import BaseModel, Field
@payload_type
@@ -117,8 +117,10 @@ class LabelConversationReplyPayload(TaskPayload):
message_id: UUID
conversation: protocol_schema.Conversation
reply: str # deprecated
reply_message: Optional[protocol_schema.ConversationMessage]
reply: Optional[str] = Field(None, deprecated=True, description="deprecated")
reply_message: Optional[protocol_schema.ConversationMessage] = Field(
None, deprecated=True, description="deprecated"
)
valid_labels: list[str]
mandatory_labels: Optional[list[str]]
mode: Optional[protocol_schema.LabelTaskMode]
@@ -0,0 +1,23 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class FlaggedMessage(SQLModel, table=True):
    # One row per flagged message, stored in the `flagged_message` table.
    __tablename__ = "flagged_message"

    # Primary key; also a foreign key to message.id, so the row is removed
    # together with the message it refers to (ON DELETE CASCADE).
    message_id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True),
            sa.ForeignKey("message.id", ondelete="CASCADE"),
            nullable=False,
            primary_key=True,
        )
    )

    # Indexed flag recording whether the entry has been processed.
    processed: bool = Field(nullable=False, index=True)

    # Indexed creation timestamp; the database fills in the current time when
    # no value is supplied.
    created_date: Optional[datetime] = Field(
        sa_column=sa.Column(
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.current_timestamp(),
            index=True,
        )
    )
+5
View File
@@ -64,6 +64,7 @@ class Message(SQLModel, table=True):
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
_user_is_author: Optional[bool] = PrivateAttr(default=None)
def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
@@ -83,3 +84,7 @@ class Message(SQLModel, table=True):
@property
def user_emojis(self) -> Optional[list[str]]:
    """Emoji codes stored in the private `_user_emojis` attribute; None when it has not been set."""
    return self._user_emojis
@property
def user_is_author(self) -> Optional[bool]:
    """Value of the private `_user_is_author` attribute; None when it has not been set."""
    return self._user_is_author
@@ -1,4 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
@@ -10,7 +12,7 @@ class State(str, Enum):
"""States of the Open-Assistant message tree state machine."""
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
"""In this state the message tree consists only of a single inital prompt root node.
"""In this state the message tree consists only of a single initial prompt root node.
Initial prompt labeling tasks will determine if the tree goes into `growing` or
`aborted_low_grade` state."""
@@ -31,11 +33,11 @@ class State(str, Enum):
compute the aggergated ranking scores that will appear in the dataset."""
READY_FOR_EXPORT = "ready_for_export"
"""The Scoring algorithm computed rankings scores for all childern. The message tree can be
"""The Scoring algorithm computed rankings scores for all children. The message tree can be
exported as part of an Open-Assistant message tree dataset."""
SCORING_FAILED = "scoring_failed"
"""An exception occured in the scoring algorithm."""
"""An exception occurred in the scoring algorithm."""
ABORTED_LOW_GRADE = "aborted_low_grade"
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
@@ -46,6 +48,9 @@ class State(str, Enum):
BACKLOG_RANKING = "backlog_ranking"
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
PROMPT_LOTTERY_WAITING = "prompt_lottery_waiting"
"""Initial prompt has passed spam check, waiting to be drawn to grow."""
VALID_STATES = (
State.INITIAL_PROMPT_REVIEW,
@@ -63,6 +68,7 @@ TERMINAL_STATES = (
State.SCORING_FAILED,
State.HALTED_BY_MODERATOR,
State.BACKLOG_RANKING,
State.PROMPT_LOTTERY_WAITING,
)
@@ -78,3 +84,4 @@ class MessageTreeState(SQLModel, table=True):
state: str = Field(nullable=False, max_length=128, index=True)
active: bool = Field(nullable=False, index=True)
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
+2
View File
@@ -31,6 +31,8 @@ class Task(SQLModel, table=True):
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
ack: Optional[bool] = None
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
skipped: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
skip_reason: str = Field(nullable=True, max_length=512)
frontend_message_id: Optional[str] = None
message_tree_id: Optional[UUID] = None
parent_message_id: Optional[UUID] = None
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class TrollStats(SQLModel, table=True):
    # Per-user moderation statistics, one row per (time_frame, user_id) pair.
    __tablename__ = "troll_stats"
    __table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)

    # --- composite primary key ---
    time_frame: Optional[str] = Field(nullable=False, primary_key=True)
    user_id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True),
            sa.ForeignKey("user.id", ondelete="CASCADE"),
            primary_key=True,
        )
    )

    # --- bookkeeping ---
    base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
    troll_score: int = 0
    modified_date: Optional[datetime] = Field(
        sa_column=sa.Column(
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.current_timestamp(),
        )
    )
    rank: int = Field(nullable=True)

    # --- counters ---
    red_flags: int = 0  # num reported messages of user
    upvotes: int = 0  # num up-voted messages of user
    downvotes: int = 0  # num down-voted messages of user
    spam_prompts: int = 0

    # --- nullable float values ---
    quality: float = Field(nullable=True)
    humor: float = Field(nullable=True)
    toxicity: float = Field(nullable=True)
    violence: float = Field(nullable=True)
    helpfulness: float = Field(nullable=True)

    # --- label counters ---
    spam: int = 0
    lang_mismach: int = 0  # spelling is part of the field/column name; do not rename here
    not_appropriate: int = 0
    pii: int = 0
    hate_speech: int = 0
    sexual_content: int = 0
    political_content: int = 0

    def compute_troll_score(self) -> int:
        """
        Compute the troll score from the counters of this row.

        Red flags count three times, up-votes are subtracted, and each of the
        remaining counters listed below adds one point per occurrence. The
        `spam` counter is not part of the sum.
        """
        single_weight_counters = (
            self.downvotes,
            self.spam_prompts,
            self.lang_mismach,
            self.not_appropriate,
            self.pii,
            self.hate_speech,
            self.sexual_content,
            self.political_content,
        )
        return self.red_flags * 3 - self.upvotes + sum(single_weight_counters)
+18
View File
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
)
# terms of service acceptance date
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
def to_protocol_frontend_user(self):
return protocol.FrontEndUser(
user_id=self.id,
@@ -55,4 +58,19 @@ class User(SQLModel, table=True):
streak_days=self.streak_days,
streak_last_day_date=self.streak_last_day_date,
last_activity_date=self.last_activity_date,
tos_acceptance_date=self.tos_acceptance_date,
)
class Account(SQLModel, table=True):
    # Links a user to an identity at an authentication provider
    # (e.g. a Discord account id).
    __tablename__ = "account"

    # The first positional argument of Index is the index *name*; the columns
    # follow. The index is named explicitly so that the unique constraint
    # covers the (provider, provider_account_id) pair, rather than
    # provider_account_id alone under an index called "provider".
    # NOTE(review): the database migration that creates this index must use
    # the same name and columns.
    __table_args__ = (
        Index("ix_account__provider__provider_account_id", "provider", "provider_account_id", unique=True),
    )

    # Primary key; generated client-side (uuid4) or by the database
    # (gen_random_uuid()) when not supplied.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
        ),
    )
    # Owning user.
    user_id: UUID = Field(foreign_key="user.id")
    provider: str = Field(nullable=False, max_length=128, default="email")  # discord or email
    # Identifier of the account at the provider.
    provider_account_id: str = Field(nullable=False, max_length=128)
+102 -17
View File
@@ -7,12 +7,14 @@ from typing import Optional
from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
import sqlalchemy.dialects.postgresql as pg
from loguru import logger
from oasst_backend.api.deps import FrontendUserId
from oasst_backend.config import settings
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import (
ApiClient,
FlaggedMessage,
Message,
MessageEmbedding,
MessageEmoji,
@@ -35,7 +37,21 @@ from oasst_shared.utils import unaware_to_utc, utcnow
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
# Lookup table of (task payload types, emoji code) pairs: reply payloads map to
# `skip_reply`, labeling payloads to `skip_labeling` and ranking payloads to
# `skip_ranking`.
_task_type_and_reaction = (
    (
        (db_payload.PrompterReplyPayload, db_payload.AssistantReplyPayload),
        protocol_schema.EmojiCode.skip_reply,
    ),
    (
        (db_payload.LabelInitialPromptPayload, db_payload.LabelConversationReplyPayload),
        protocol_schema.EmojiCode.skip_labeling,
    ),
    (
        (db_payload.RankInitialPromptsPayload, db_payload.RankConversationRepliesPayload),
        protocol_schema.EmojiCode.skip_ranking,
    ),
)
class PromptRepository:
@@ -77,7 +93,14 @@ class PromptRepository:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
if self.user.deleted or not self.user.enabled:
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
raise OasstError(
"User has not accepted terms of service.",
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
)
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
validate_frontend_message_id(frontend_message_id)
@@ -90,7 +113,7 @@ class PromptRepository:
raise OasstError(
f"Message with frontend_message_id {frontend_message_id} not found.",
OasstErrorCode.MESSAGE_NOT_FOUND,
HTTP_404_NOT_FOUND,
HTTPStatus.NOT_FOUND,
)
return message
@@ -134,12 +157,15 @@ class PromptRepository:
review_result=review_result,
)
self.db.add(message)
# self.db.refresh(message)
return message
def _validate_task(
self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None
self,
task: Task,
*,
task_id: Optional[UUID] = None,
frontend_message_id: Optional[str] = None,
check_ack: bool = True,
) -> Task:
if task is None:
if task_id:
@@ -150,7 +176,7 @@ class PromptRepository:
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if not task.ack:
if check_ack and not task.ack:
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
if task.done:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
@@ -262,6 +288,10 @@ class PromptRepository:
task.done = True
self.db.add(task)
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
logger.debug(
f"Inserted message id={user_message.id}, tree={user_message.message_tree_id}, user_id={user_message.user_id}, "
f"text[:100]='{user_message.text[:100]}', role='{user_message.role}', lang='{user_message.lang}'"
)
return user_message
@managed_tx_method(CommitMode.FLUSH)
@@ -456,7 +486,7 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
logger.debug(f"text_labels relpy: {valid_labels=}, {mandatory_labels=}")
logger.debug(f"text_labels reply: {valid_labels=}, {mandatory_labels=}")
if valid_labels:
if not all([label in valid_labels for label in text_labels.labels.keys()]):
@@ -628,12 +658,6 @@ class PromptRepository:
qry = qry.filter(not_(Message.deleted))
return self._add_user_emojis_all(qry)
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
qry = self.db.query(MessageTreeState).filter(
MessageTreeState.state == message_tree_state.State.READY_FOR_EXPORT
)
return qry.all()
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
"""
Fetch a conversation with multiple possible replies to it.
@@ -675,7 +699,7 @@ class PromptRepository:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
return message
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
@@ -848,6 +872,7 @@ class PromptRepository:
user_emojis = x["user_emojis"]
if user_emojis:
m._user_emojis = user_emojis.split(",")
m._user_is_author = self.user_id and self.user_id == m.user_id
messages.append(m)
return messages
@@ -874,7 +899,7 @@ class PromptRepository:
if api_client_id != self.api_client.id:
# Unprivileged api client asks for foreign messages
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
qry = self.db.query(Message)
if user_id:
@@ -995,7 +1020,31 @@ WHERE message.id = cc.id;
message_trees=result.get(None, 0),
)
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
@managed_tx_method()
def skip_task(self, task_id: UUID, reason: str):
    """
    Record that the current user skipped a task.

    The task is validated without requiring a prior acknowledgement. A
    non-collective task is marked as skipped and the reason is stored with it.
    In addition, the skip emoji matching the task's payload type (see
    `_task_type_and_reaction`) is added to the task's parent message; payload
    types that are not in the table get no emoji.
    """
    self.ensure_user_is_enabled()

    task = self.task_repository.fetch_task_by_id(task_id)
    self._validate_task(task, check_ack=False)

    if not task.collective:
        task.skipped = True
        task.skip_reason = reason
        self.db.add(task)

    payload: db_payload.TaskPayload = task.payload.payload
    for payload_types, skip_emoji in _task_type_and_reaction:
        # isinstance accepts the whole tuple of types of one table entry.
        if isinstance(payload, payload_types):
            self.handle_message_emoji(task.parent_message_id, protocol_schema.EmojiOp.add, skip_emoji)
            break  # only the first matching entry applies
def handle_message_emoji(
self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema.EmojiCode
) -> Message:
self.ensure_user_is_enabled()
message = self.fetch_message(message_id)
@@ -1038,6 +1087,22 @@ WHERE message.id = cc.id;
message_id, protocol_schema.EmojiOp.remove, protocol_schema.EmojiCode.thumbs_up
)
if message.user_id == self.user_id and emoji in (
protocol_schema.EmojiCode.thumbs_up,
protocol_schema.EmojiCode.thumbs_down,
):
logger.debug(f"Ignoring add emoji op for user's own message ({emoji=})")
return message
# Add to flagged_message table if the red flag emoji is applied
if emoji == protocol_schema.EmojiCode.red_flag:
flagged_message = FlaggedMessage(message_id=message_id, processed=False, created_date=utcnow())
insert_stmt = pg.insert(FlaggedMessage).values(**flagged_message.dict())
upsert_stmt = insert_stmt.on_conflict_do_update(
constraint="flagged_message_pkey", set_=flagged_message.dict()
)
self.db.execute(upsert_stmt)
# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
@@ -1073,3 +1138,23 @@ WHERE message.id = cc.id;
self.db.add(message)
self.db.flush()
return message
def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessage]:
    """Return flagged-message rows; at most `max_count` of them unless `max_count` is None."""
    flagged = self.db.query(FlaggedMessage)
    if max_count is None:
        return flagged.all()
    return flagged.limit(max_count).all()
def process_flagged_message(self, message_id: UUID) -> FlaggedMessage:
    """
    Mark the flagged-message entry for `message_id` as processed.

    The change is committed immediately and the refreshed row is returned.

    Raises:
        OasstError: (HTTP 404) when no flagged-message entry exists for the id.
    """
    flagged = self.db.query(FlaggedMessage).get(message_id)
    if not flagged:
        raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)

    flagged.processed = True
    self.db.commit()
    # Reload the row so the returned object reflects the committed state.
    self.db.refresh(flagged)
    return flagged
+27 -24
View File
@@ -1,16 +1,18 @@
from datetime import timedelta
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, or_
from oasst_shared.utils import utcnow
from sqlmodel import Session, delete, false, func, or_
from starlette.status import HTTP_404_NOT_FOUND
@@ -24,6 +26,13 @@ def validate_frontend_message_id(message_id: str) -> None:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
def delete_expired_tasks(session: Session) -> int:
    """
    Delete all tasks whose expiry date has passed and that are not done.

    Returns the number of rows reported as deleted by the statement. The
    function itself does not commit.
    """
    result = session.exec(
        delete(Task).where(
            Task.expiry_date < utcnow(),
            Task.done == false(),
        )
    )
    logger.info(f"Deleted {result.rowcount} expired tasks.")
    return result.rowcount
class TaskRepository:
def __init__(
self,
@@ -100,8 +109,6 @@ class TaskRepository:
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
reply_message=task.reply_message,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
@@ -112,8 +119,6 @@ class TaskRepository:
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
reply_message=task.reply_message,
valid_labels=task.valid_labels,
mandatory_labels=task.mandatory_labels,
mode=task.mode,
@@ -122,12 +127,18 @@ class TaskRepository:
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
if not collective and settings.TASK_VALIDITY_MINUTES > 0:
expiry_date = utcnow() + timedelta(minutes=settings.TASK_VALIDITY_MINUTES)
else:
expiry_date = None
task_model = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
expiry_date=expiry_date,
)
assert task_model.id == task.id
return task_model
@@ -166,26 +177,11 @@ class TaskRepository:
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
raise OasstError("Already closed", OasstErrorCode.TASK_ALREADY_DONE)
task.done = True
self.db.add(task)
@managed_tx_method(CommitMode.COMMIT)
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
@managed_tx_method(CommitMode.COMMIT)
def insert_task(
self,
@@ -194,6 +190,7 @@ class TaskRepository:
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
expiry_date: datetime = None,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
@@ -205,6 +202,7 @@ class TaskRepository:
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
expiry_date=expiry_date,
)
logger.debug(f"inserting {task=}")
self.db.add(task)
@@ -224,14 +222,19 @@ class TaskRepository:
return task
def fetch_recent_reply_tasks(
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, limit: int = 100
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100
) -> list[Task]:
qry = self.db.query(Task).filter(
func.age(Task.created_date) < max_age,
func.age(func.current_timestamp(), Task.created_date) < max_age,
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
)
if done is not None:
qry = qry.filter(Task.done == done)
if skipped is not None:
qry = qry.filter(Task.skipped == skipped)
if limit:
qry = qry.limit(limit)
return qry.all()
def delete_expired(self) -> int:
    """Delete expired, unfinished tasks via `delete_expired_tasks` on this repository's session and return the row count."""
    return delete_expired_tasks(self.db)
+367 -112
View File
@@ -1,6 +1,4 @@
import json
import random
import sys
from datetime import datetime, timedelta
from enum import Enum
from http import HTTPStatus
@@ -9,23 +7,34 @@ from uuid import UUID
import numpy as np
import pydantic
from fastapi.encoders import jsonable_encoder
import sqlalchemy as sa
from loguru import logger
from oasst_backend.api.v1.utils import (
prepare_conversation,
prepare_conversation_message,
prepare_conversation_message_list,
)
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
from oasst_backend.models import (
Message,
MessageEmoji,
MessageReaction,
MessageTreeState,
Task,
TextLabels,
User,
message_tree_state,
)
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils import tree_export
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.database_utils import (
CommitMode,
async_managed_tx_method,
managed_tx_function,
managed_tx_method,
)
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, not_, or_, text, update
from oasst_shared.utils import utcnow
from sqlalchemy.sql.functions import coalesce
from sqlmodel import Session, and_, func, not_, or_, text, update
class TaskType(Enum):
@@ -43,6 +52,19 @@ class TaskRole(Enum):
ASSISTANT = 2
class TreeStateStats(pydantic.BaseModel):
    """Number of message trees per tree state.

    There is one integer field per state; ``tree_counts_by_state_stats`` fills
    each field with the count found for the corresponding state, or 0 when no
    count is available for it.
    """

    # One counter per message tree state.
    initial_prompt_review: int
    growing: int
    ranking: int
    ready_for_scoring: int
    scoring_failed: int
    ready_for_export: int
    aborted_low_grade: int
    halted_by_moderator: int
    backlog_ranking: int
    prompt_lottery_waiting: int
class ActiveTreeSizeRow(pydantic.BaseModel):
message_tree_id: UUID
goal_tree_size: int
@@ -157,7 +179,7 @@ class TreeManager:
def _determine_task_availability_internal(
self,
num_active_trees: int,
num_missing_prompts: int,
extendible_parents: list[ExtendibleParentRow],
prompts_need_review: list[Message],
replies_need_review: list[Message],
@@ -165,7 +187,6 @@ class TreeManager:
) -> dict[protocol_schema.TaskRequestType, int]:
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}
num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees)
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
@@ -198,6 +219,117 @@ class TreeManager:
return task_count_by_type
def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
    """Activate waiting prompts of language `lang` through a random draw.

    As long as growing trees are missing (fewer active trees in growing state
    than `cfg.max_active_trees`) and fewer than `max_activate` trees have been
    activated by this call, one tree in PROMPT_LOTTERY_WAITING state is moved
    to GROWING. The winner is drawn in two steps: first a random author among
    the distinct authors of waiting prompts, then a random prompt of that
    author. Authors are therefore drawn uniformly, independent of how many
    prompts each of them submitted.

    Args:
        lang: Language of the root prompts to consider.
        max_activate: Upper bound for the number of trees activated per call.

    Returns:
        `num_missing_growing + remaining_prompt_review`, computed from the
        most recently queried tree-state statistics.
    """
    # Under high load the DB runs into deadlocks when many trees are released
    # simultaneously (happens when the max_active_trees setting is increased).
    # To reduce the chance of write conflicts during updates of rows in the
    # message_tree_state table we limit the number of trees that are activated
    # per _prompt_lottery() call to max_activate.

    activated = 0

    while True:
        # Re-query the statistics on every round so the result reflects
        # activations done in previous rounds.
        stats = self.tree_counts_by_state_stats(lang=lang, only_active=True)

        remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
        num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
        logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")

        if num_missing_growing == 0 or activated >= max_activate:
            return num_missing_growing + remaining_prompt_review

        # NOTE(review): presumably the decorator runs the function in its own
        # committed transaction and supplies `db` -- confirm in database_utils.
        @managed_tx_function(CommitMode.COMMIT)
        def activate_one(db: Session) -> bool:
            """Activate one random waiting tree; return False if none is available."""
            # select among distinct users
            authors_qry = (
                db.query(Message.user_id)
                .select_from(MessageTreeState)
                .join(Message, MessageTreeState.message_tree_id == Message.id)
                .filter(
                    MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
                    Message.lang == lang,
                    not_(Message.deleted),
                    Message.review_result,
                )
                .distinct(Message.user_id)
            )

            author_ids = authors_qry.all()
            if not author_ids:
                logger.info(
                    f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
                )
                return False

            # first select an author
            prompt_author_id: UUID = random.choice(author_ids)["user_id"]
            logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")

            # select random prompt of author (at most 100 candidates are considered)
            qry = (
                db.query(MessageTreeState, Message)
                .select_from(MessageTreeState)
                .join(Message, MessageTreeState.message_tree_id == Message.id)
                .filter(
                    MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
                    Message.user_id == prompt_author_id,
                    Message.lang == lang,
                    not_(Message.deleted),
                    Message.review_result,
                )
                .limit(100)
            )

            prompt_candidates = qry.all()
            if not prompt_candidates:
                # Can happen when the author's prompts changed state between the two queries.
                logger.warning("No prompt candidates of selected author found.")
                return False

            winner_prompt = random.choice(prompt_candidates)
            message: Message = winner_prompt.Message
            logger.info(f"Prompt lottery winner: {message.id=}")

            mts: MessageTreeState = winner_prompt.MessageTreeState
            mts.state = message_tree_state.State.GROWING
            mts.active = True
            db.add(mts)

            # Keep the date of the first win if the tree already has one.
            if mts.won_prompt_lottery_date is None:
                mts.won_prompt_lottery_date = utcnow()
            logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")

            return True

        if not activate_one():
            return num_missing_growing + remaining_prompt_review

        activated += 1
def _auto_moderation(self, lang: str) -> None:
    """Apply automatic moderation to flagged messages of language `lang`.

    Does nothing unless `cfg.auto_mod_enabled` is set. For every message
    returned by `query_moderation_bad_messages`:

    * red-flag count at or above `cfg.auto_mod_red_flags`: an initial prompt
      (no parent) sends its tree to the low-grade state, any other message is
      marked deleted recursively;
    * skip-reply count at or above `cfg.auto_mod_max_skip_reply`:
      `halt_tree` is called for the message.
    """
    if not self.cfg.auto_mod_enabled:
        return

    # NOTE: the loop variable must stay `m`, the `{m.id=}` log output contains its name.
    for m in self.query_moderation_bad_messages(lang=lang):
        num_red_flag = m.emojis.get(protocol_schema.EmojiCode.red_flag)
        too_many_red_flags = num_red_flag is not None and num_red_flag >= self.cfg.auto_mod_red_flags

        if too_many_red_flags:
            if m.parent_id is not None:
                logger.warning(f"[AUTO MOD] Deleting message {m.id=}, it received too many red flags ({m.emojis}).")
                self.pr.mark_messages_deleted(m.id, recursive=True)
            else:
                # Logged as a halt; the call made here is enter_low_grade_state.
                logger.warning(
                    f"[AUTO MOD] Halting tree {m.message_tree_id}, initial prompt got too many red flags ({m.emojis})."
                )
                self.enter_low_grade_state(m.message_tree_id)

        # The skip-reply check runs independently of the red-flag outcome.
        num_skip_reply = m.emojis.get(protocol_schema.EmojiCode.skip_reply)
        too_many_skips = num_skip_reply is not None and num_skip_reply >= self.cfg.auto_mod_max_skip_reply

        if too_many_skips:
            logger.warning(
                f"[AUTO MOD] Halting tree {m.message_tree_id} due to high skip-reply count of message {m.id=} ({m.emojis})."
            )
            self.halt_tree(m.id, halt=True)
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()
@@ -205,14 +337,15 @@ class TreeManager:
lang = "en"
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=1)
extendible_parents, _ = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
return self._determine_task_availability_internal(
num_active_trees=num_active_trees,
num_missing_prompts=num_missing_prompts,
extendible_parents=extendible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
@@ -242,7 +375,9 @@ class TreeManager:
lang = "en"
logger.warning("Task request without lang tag received, assuming 'en'.")
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
@@ -260,7 +395,7 @@ class TreeManager:
num_ranking_tasks=len(incomplete_rankings),
num_replies_need_review=len(replies_need_review),
num_prompts_need_review=len(prompts_need_review),
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
num_missing_prompts=num_missing_prompts,
num_missing_replies=num_missing_replies,
)
@@ -272,7 +407,7 @@ class TreeManager:
)
else:
task_count_by_type = self._determine_task_availability_internal(
num_active_trees=num_active_trees,
num_missing_prompts=num_missing_prompts,
extendible_parents=extendible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
@@ -386,7 +521,6 @@ class TreeManager:
message_id=message.id,
conversation=conversation,
reply=message.text,
reply_message=prepare_conversation_message(message),
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
mode=label_mode,
@@ -412,7 +546,6 @@ class TreeManager:
message_id=message.id,
conversation=conversation,
reply=message.text,
reply_message=prepare_conversation_message(message),
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
mode=label_mode,
@@ -425,11 +558,6 @@ class TreeManager:
case TaskType.REPLY:
recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), done=False
)
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}
if task_role == TaskRole.PROMPTER:
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
elif task_role == TaskRole.ASSISTANT:
@@ -440,24 +568,17 @@ class TreeManager:
random_parent: ExtendibleParentRow = None
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
# check if we have extendible prompter parents with a small number of replies
lonely_children_parents = [
p
for p in extendible_parents
if 0 < p.active_children_count < self.cfg.lonely_children_count
and p.parent_role == "prompter"
and p.parent_id not in recent_reply_task_parents
]
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
random_parent = random.choice(lonely_children_parents)
if random_parent is None:
# try to exclude parents for which tasks were recently handed out
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
if len(fresh_parents) > 0:
random_parent = random.choice(fresh_parents)
else:
random_parent = random.choice(extendible_parents)
random_parent = random.choice(extendible_parents)
# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
@@ -479,6 +600,7 @@ class TreeManager:
case TaskType.LABEL_PROMPT:
assert len(prompts_need_review) > 0
message = random.choice(prompts_need_review)
message = self.pr.fetch_message(message.id) # re-fetch message including emojis
label_mode = protocol_schema.LabelTaskMode.full
label_disposition = protocol_schema.LabelTaskDisposition.quality
@@ -495,6 +617,7 @@ class TreeManager:
task = protocol_schema.LabelInitialPromptTask(
message_id=message.id,
prompt=message.text,
conversation=prepare_conversation([message]),
valid_labels=list(map(lambda x: x.value, valid_labels)),
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_initial_prompt)),
mode=label_mode,
@@ -519,7 +642,8 @@ class TreeManager:
HTTPStatus.SERVICE_UNAVAILABLE,
)
logger.info(f"Generated {task=}.")
logger.info(f"Generated task (type={task.type}, id={task.id})")
logger.debug(f"Generated {task=}.")
return task, message_tree_id, parent_message_id
@@ -530,8 +654,9 @@ class TreeManager:
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
f"Frontend reports text reply to message_id={interaction.message_id} by user={interaction.user}."
)
logger.debug(f"with {interaction.text=}")
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
@@ -578,23 +703,26 @@ class TreeManager:
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
f"Frontend reports rating of message_id={interaction.message_id} by user={interaction.user}."
)
logger.debug(f"with {interaction.rating=}")
pr.store_rating(interaction)
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
f"Frontend reports ranking of message_id={interaction.message_id} by user={interaction.user}."
)
logger.debug(f"with {interaction.ranking=}")
_, task = pr.store_ranking(interaction)
self.check_condition_for_scoring_state(task.message_tree_id)
case protocol_schema.TextLabels:
logger.info(
f"Frontend reports labels of {interaction.message_id=} with {interaction.labels=} by {interaction.user=}."
f"Frontend reports labels of message_id={interaction.message_id} by user={interaction.user}."
)
logger.debug(f"with {interaction.labels=}")
_, task, msg = pr.store_text_labels(interaction)
@@ -615,7 +743,7 @@ class TreeManager:
)
else:
self.enter_low_grade_state(msg.message_tree_id)
self.check_condition_for_growing_state(msg.message_tree_id)
self.check_condition_for_prompt_lottery(msg.message_tree_id)
elif msg.review_count >= self.cfg.num_reviews_reply:
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
msg.review_result = True
@@ -649,10 +777,12 @@ class TreeManager:
self.activate_backlog_tree(lang=root_msg.lang)
if self.cfg.min_active_rankings_per_lang > 0:
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang, user_filter=False)
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
self.activate_backlog_tree(lang=root_msg.lang)
else:
if mts.state == message_tree_state.State.GROWING and mts.won_prompt_lottery_date is None:
mts.won_prompt_lottery_date = utcnow()
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
@@ -660,8 +790,8 @@ class TreeManager:
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
def check_condition_for_prompt_lottery(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_prompt_lottery({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW:
@@ -674,7 +804,7 @@ class TreeManager:
logger.debug(f"False {initial_prompt.review_result=}")
return False
self._enter_state(mts, message_tree_state.State.GROWING)
self._enter_state(mts, message_tree_state.State.PROMPT_LOTTERY_WAITING)
return True
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
@@ -746,7 +876,7 @@ class TreeManager:
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
continue
# keep only elements in commond set
# keep only elements in common set
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
assert all(len(x) == len(common_set) for x in ordered_ids_list)
@@ -824,6 +954,14 @@ class TreeManager:
self.db.query(Message)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.outerjoin(
MessageEmoji,
and_(
Message.id == MessageEmoji.message_id,
MessageEmoji.user_id == self.pr.user_id,
MessageEmoji.emoji == protocol_schema.EmojiCode.skip_labeling,
),
)
.filter(
MessageTreeState.active,
MessageTreeState.state == state,
@@ -831,6 +969,7 @@ class TreeManager:
not_(Message.deleted),
Message.review_count < required_reviews,
Message.lang == lang,
MessageEmoji.message_id.is_(None),
)
)
@@ -883,14 +1022,23 @@ SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) chi
mts.message_tree_id
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
INNER JOIN message p ON m.parent_id = p.id
LEFT JOIN message_emoji me on
(m.parent_id = me.message_id
AND :skip_user_id IS NOT NULL
AND me.user_id = :skip_user_id
AND me.emoji = :skip_ranking)
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND m.lang = :lang -- matches lang
AND p.lang = :lang -- parent lang matches
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
AND me.message_id IS NULL -- no skip ranking emoji for user
GROUP BY m.parent_id, m.role, mts.message_tree_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
HAVING COUNT(m.id) > 1 -- more than one child
AND MIN(m.ranking_count) < :num_required_rankings -- not complete
AND COUNT(m.id) FILTER (WHERE m.user_id = :rank_user_id) = 0 -- no self-ranking
"""
_sql_find_incomplete_rankings_ex = f"""
@@ -900,29 +1048,54 @@ SELECT ir.* FROM incomplete_rankings ir
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
ir.message_tree_id
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
HAVING COUNT(mr.message_id) FILTER (WHERE mr.user_id = :dupe_user_id) = 0
"""
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
"""Query parents which have childern that need further rankings"""
def query_incomplete_rankings(self, lang: str, user_filter: bool = True) -> list[IncompleteRankingsRow]:
"""Query parents which have children that need further rankings"""
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
dupe_user_id = None
skip_user_id = None
rank_user_id = None
if user_filter:
if not settings.DEBUG_ALLOW_DUPLICATE_TASKS:
dupe_user_id = self.pr.user_id
if not settings.DEBUG_ALLOW_SELF_RANKING:
rank_user_id = self.pr.user_id
skip_user_id = self.pr.user_id
r = self.db.execute(
text(self._sql_find_incomplete_rankings_ex),
{
"num_required_rankings": self.cfg.num_required_rankings,
"ranking_state": message_tree_state.State.RANKING,
"lang": lang,
"user_id": user_id,
"dupe_user_id": dupe_user_id,
"skip_user_id": skip_user_id,
"rank_user_id": rank_user_id,
"ranking_state": message_tree_state.State.RANKING,
"skip_ranking": protocol_schema.EmojiCode.skip_ranking,
},
)
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
_sql_find_extendible_parents = """
-- find all extendible parent nodes
WITH recent_reply_tasks (parent_message_id) AS (
-- recent incomplete tasks to exclude
SELECT parent_message_id FROM task
WHERE not done
AND not skipped
AND created_date > (CURRENT_TIMESTAMP - :recent_tasks_interval)
AND (payload_type = 'AssistantReplyPayload' OR payload_type = 'PrompterReplyPayload')
)
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message_emoji me ON
(m.id = me.message_id
AND :skip_user_id IS NOT NULL
AND me.user_id = :skip_user_id
AND me.emoji = :skip_reply)
LEFT JOIN recent_reply_tasks rrt ON m.id = rrt.parent_message_id -- recent tasks
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
@@ -930,6 +1103,8 @@ WHERE mts.active -- only consider active trees
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
AND m.lang = :lang -- parent matches lang
AND me.message_id IS NULL -- no skip reply emoji for user
AND rrt.parent_message_id IS NULL -- no recent reply task found
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
@@ -950,6 +1125,9 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"num_prompter_replies": self.cfg.num_prompter_replies,
"lang": lang,
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_reply": protocol_schema.EmojiCode.skip_reply,
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
},
)
@@ -988,6 +1166,9 @@ HAVING COUNT(m.id) < mts.goal_tree_size
"num_prompter_replies": self.cfg.num_prompter_replies,
"lang": lang,
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_reply": protocol_schema.EmojiCode.skip_reply,
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
@@ -1079,7 +1260,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
@managed_tx_method(CommitMode.COMMIT)
def ensure_tree_states(self) -> None:
"""Add message tree state rows for all root nodes (inital prompt messages)."""
"""Add message tree state rows for all root nodes (initial prompt messages)."""
missing_tree_ids = self.query_misssing_tree_states()
for id in missing_tree_ids:
@@ -1101,7 +1282,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
f"Checking state of {len(prompt_review_trees)} active message trees in 'initial_prompt_review' state."
)
for t in prompt_review_trees:
self.check_condition_for_growing_state(t.message_tree_id)
self.check_condition_for_prompt_lottery(t.message_tree_id)
growing_trees: list[MessageTreeState] = (
self.db.query(MessageTreeState)
@@ -1129,8 +1310,23 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
for t in ranking_trees:
self.check_condition_for_scoring_state(t.message_tree_id)
def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int:
"""Count all active trees (optionally exclude those in ranking state)."""
def query_num_growing_trees(self, lang: str) -> int:
    """Return the number of active trees in growing state for language `lang`.

    The language is taken from the message whose id equals the tree id.
    """
    tree_count = func.count(MessageTreeState.message_tree_id)
    conditions = (
        MessageTreeState.active,
        MessageTreeState.state == message_tree_state.State.GROWING,
        Message.lang == lang,
    )
    growing_qry = self.db.query(tree_count).join(Message, MessageTreeState.message_tree_id == Message.id)
    return growing_qry.filter(*conditions).scalar()
def query_num_active_trees(
self, lang: str, exclude_ranking: bool = True, exclude_prompt_review: bool = True
) -> int:
"""Count all active trees (optionally exclude those in ranking and initial prompt review states)."""
query = (
self.db.query(func.count(MessageTreeState.message_tree_id))
.join(Message, MessageTreeState.message_tree_id == Message.id)
@@ -1138,6 +1334,8 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
)
if exclude_ranking:
query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING)
if exclude_prompt_review:
query = query.filter(MessageTreeState.state != message_tree_state.State.INITIAL_PROMPT_REVIEW)
return query.scalar()
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
@@ -1149,6 +1347,37 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
)
return qry.all()
def query_moderation_bad_messages(self, lang: str) -> list[Message]:
    """Find non-deleted messages whose emoji counts reached an auto-moderation threshold.

    Only messages of active trees in INITIAL_PROMPT_REVIEW or GROWING state
    are considered. A message qualifies if it is an initial prompt, has a
    positive review result, or is a reply with fewer than
    `cfg.num_reviews_reply` reviews. It is returned when its red-flag count
    is at least `cfg.auto_mod_red_flags` or its skip-reply count is at least
    `cfg.auto_mod_max_skip_reply`. If `lang` is not None the result is
    restricted to that language.
    """
    tree_in_moderated_state = or_(
        MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
        MessageTreeState.state == message_tree_state.State.GROWING,
    )
    message_is_candidate = or_(
        Message.parent_id.is_(None),
        Message.review_result,
        and_(Message.parent_id.is_not(None), Message.review_count < self.cfg.num_reviews_reply),
    )
    # Emoji counters missing from the emojis mapping are treated as 0.
    red_flag_count = coalesce(Message.emojis[protocol_schema.EmojiCode.red_flag].cast(sa.Integer), 0)
    skip_reply_count = coalesce(Message.emojis[protocol_schema.EmojiCode.skip_reply].cast(sa.Integer), 0)
    threshold_reached = or_(
        red_flag_count >= self.cfg.auto_mod_red_flags,
        skip_reply_count >= self.cfg.auto_mod_max_skip_reply,
    )

    qry = self.db.query(Message).select_from(MessageTreeState)
    qry = qry.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
    qry = qry.filter(
        MessageTreeState.active,
        tree_in_moderated_state,
        message_is_candidate,
        not_(Message.deleted),
        threshold_reached,
    )

    if lang is not None:
        qry = qry.filter(Message.lang == lang)

    return qry.all()
@managed_tx_method(CommitMode.FLUSH)
def _insert_tree_state(
self,
@@ -1176,22 +1405,54 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
self,
root_message_id: UUID,
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
*,
goal_tree_size: int = None,
) -> MessageTreeState:
if goal_tree_size is None:
if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size:
goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size)
else:
goal_tree_size = self.cfg.goal_tree_size
return self._insert_tree_state(
root_message_id=root_message_id,
goal_tree_size=self.cfg.goal_tree_size,
goal_tree_size=goal_tree_size,
max_depth=self.cfg.max_tree_depth,
max_children_count=self.cfg.max_children_count,
state=state,
active=True,
)
def tree_counts_by_state(self) -> dict[str, int]:
qry = self.db.query(
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
).group_by(MessageTreeState.state)
def tree_counts_by_state(self, lang: str = None, only_active: bool = False) -> dict[str, int]:
    """Count message trees grouped by tree state.

    Args:
        lang: If not None, only trees whose message with id equal to the
            tree id has this language are counted.
        only_active: If true, only active trees are counted.

    Returns:
        Mapping from state to the number of trees in that state; states
        without any matching tree do not appear.
    """
    tree_count = func.count(MessageTreeState.message_tree_id).label("count")
    qry = self.db.query(MessageTreeState.state, tree_count)

    if lang is not None:
        qry = qry.select_from(MessageTreeState)
        qry = qry.join(Message, MessageTreeState.message_tree_id == Message.id)
        qry = qry.filter(Message.lang == lang)

    if only_active:
        qry = qry.filter(MessageTreeState.active)

    counts = {}
    for row in qry.group_by(MessageTreeState.state):
        counts[row["state"]] = row["count"]
    return counts
def tree_counts_by_state_stats(self, lang: str = None, only_active: bool = False) -> TreeStateStats:
    """Return the per-state tree counts as a `TreeStateStats` object.

    Takes the same filters as `tree_counts_by_state`. States for which that
    method reports no count (or a falsy count) are reported as 0.
    """
    count_by_state = self.tree_counts_by_state(lang=lang, only_active=only_active)

    def count_of(state) -> int:
        # Missing entries default to 0.
        return count_by_state.get(state) or 0

    return TreeStateStats(
        initial_prompt_review=count_of(message_tree_state.State.INITIAL_PROMPT_REVIEW),
        growing=count_of(message_tree_state.State.GROWING),
        ranking=count_of(message_tree_state.State.RANKING),
        ready_for_scoring=count_of(message_tree_state.State.READY_FOR_SCORING),
        ready_for_export=count_of(message_tree_state.State.READY_FOR_EXPORT),
        scoring_failed=count_of(message_tree_state.State.SCORING_FAILED),
        halted_by_moderator=count_of(message_tree_state.State.HALTED_BY_MODERATOR),
        backlog_ranking=count_of(message_tree_state.State.BACKLOG_RANKING),
        prompt_lottery_waiting=count_of(message_tree_state.State.PROMPT_LOTTERY_WAITING),
        aborted_low_grade=count_of(message_tree_state.State.ABORTED_LOW_GRADE),
    )
def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
qry = (
self.db.query(
@@ -1274,9 +1535,32 @@ DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id
DELETE FROM task t WHERE t.parent_message_id = :message_id;
DELETE FROM message WHERE id = :message_id;
"""
parent_id = self.pr.fetch_message(message_id=message_id).parent_id
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
sql_update_ranking_counts = """
WITH r AS (
-- find ranking results and count per child
SELECT c.id,
count(*) FILTER (
WHERE mr.payload#>'{payload, ranked_message_ids}' ? CAST(c.id AS varchar)
) AS ranking_count
FROM message c
LEFT JOIN message_reaction mr ON mr.payload_type = 'RankingReactionPayload'
AND mr.message_id = c.parent_id
WHERE c.parent_id = :parent_id
GROUP BY c.id
)
UPDATE message m SET ranking_count = r.ranking_count
FROM r WHERE m.id = r.id AND m.ranking_count != r.ranking_count;
"""
if parent_id is not None:
# update ranking counts of remaining children
r = self.db.execute(text(sql_update_ranking_counts), {"parent_id": parent_id})
logger.debug(f"ranking_count updated for {r.rowcount} rows.")
def purge_message_tree(self, message_tree_id: UUID) -> None:
sql_purge_message_tree = """
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
@@ -1292,11 +1576,17 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.")
def _reactivate_tree(self, mts: MessageTreeState):
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
if mts.state == message_tree_state.State.PROMPT_LOTTERY_WAITING:
return
tree_id = mts.message_tree_id
if self.check_condition_for_growing_state(tree_id):
if mts.won_prompt_lottery_date is not None:
self._enter_state(mts, message_tree_state.State.GROWING)
if self.check_condition_for_ranking_state(tree_id):
self.check_condition_for_scoring_state(tree_id)
else:
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
self.check_condition_for_prompt_lottery(tree_id)
@managed_tx_method(CommitMode.FLUSH)
def purge_user_messages(
@@ -1312,7 +1602,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
total_messages = sum(len(x) for x in replies_by_tree.values())
logger.debug(f"found: {len(replies_by_tree)} trees; {len(prompts)} prompts; {total_messages} messages;")
# remove all trees based on inital prompts of the user
# remove all trees based on initial prompts of the user
if purge_initial_prompts:
for p in prompts:
self.purge_message_tree(p.message_tree_id)
@@ -1350,7 +1640,7 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
logger.debug(f"purging message: {m.id}")
self._purge_message_internal(m.id)
# update childern counts
# update children counts
self.pr.update_children_counts(m.message_tree_id)
# reactivate tree
@@ -1379,44 +1669,6 @@ DELETE FROM user_stats WHERE user_id = :user_id;
if ban:
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
def export_trees_to_file(
    self,
    message_tree_ids: list[str],
    file=None,
    reviewed: bool = True,
    include_deleted: bool = False,
    use_compression: bool = False,
) -> None:
    """Export the given message trees to `file`, or to stdout as JSON.

    For each tree id the messages are fetched with
    `fetch_message_tree(tree_id, reviewed, include_deleted)` and converted
    with `tree_export.build_export_tree`. When `file` is falsy the trees are
    written to stdout as indented JSON, and `use_compression` is not used.
    """
    # Fetch and convert one tree at a time, in the order of the given ids.
    trees_to_export = [
        tree_export.build_export_tree(
            tree_id,
            self.pr.fetch_message_tree(tree_id, reviewed, include_deleted),
        )
        for tree_id in message_tree_ids
    ]

    if not file:
        sys.stdout.write(json.dumps(jsonable_encoder(trees_to_export), indent=2))
        return

    tree_export.write_trees_to_file(file, trees_to_export, use_compression)
def export_all_ready_trees(
    self, file: str, reviewed: bool = True, include_deleted: bool = False, use_compression: bool = False
) -> None:
    """Export every tree returned by `fetch_message_trees_ready_for_export`.

    The tree ids are collected and passed on to `export_trees_to_file`
    together with the remaining arguments.
    """
    # The result is iterated below, so it is a collection of tree states
    # (presumably a list -- confirm against PromptRepository).
    message_tree_states: list[MessageTreeState] = self.pr.fetch_message_trees_ready_for_export()
    message_tree_ids = [ms.message_tree_id for ms in message_tree_states]
    self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
def export_all_user_trees(
    self,
    user_id: str,
    file: str,
    reviewed: bool = True,
    include_deleted: bool = False,
    use_compression: bool = False,
) -> None:
    """Export the message trees returned by `fetch_user_message_trees` for a user.

    Args:
        user_id: User id as a string; it is converted with `UUID(user_id)`,
            which raises `ValueError` for a malformed value.
        file, reviewed, include_deleted, use_compression: Passed through to
            `export_trees_to_file`.
    """
    messages = self.pr.fetch_user_message_trees(UUID(user_id))
    # One tree id per returned message; ids are not de-duplicated here.
    message_tree_ids = [ms.message_tree_id for ms in messages]
    self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
@managed_tx_method(CommitMode.COMMIT)
def retry_scoring_failed_message_trees(self):
query = self.db.query(MessageTreeState).filter(
@@ -1454,7 +1706,8 @@ if __name__ == "__main__":
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
# api_client = create_api_client(session=db, description="test", frontend_type="bot")
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
# dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
dummy_user = protocol_schema.User(id="1234", display_name="bulb", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
cfg = TreeManagerConfiguration()
@@ -1469,14 +1722,16 @@ if __name__ == "__main__":
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
# print("query_replies_need_review", tm.query_replies_need_review())
# print("query_incomplete_reply_reviews", tm.query_replies_need_review())
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
xs = tm.query_prompts_need_review(lang="en")
print("xs", len(xs))
for x in xs:
print(x.id, x.emojis)
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review(lang="en"))
# print("query_extendible_trees", tm.query_extendible_trees())
# print("query_extendible_parents", tm.query_extendible_parents())
# print("next_task:", tm.next_task())
print(
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
)
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
# print(
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
# )
+12 -2
View File
@@ -73,7 +73,8 @@ class UserRepository:
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
) -> None:
tos_acceptance: Optional[bool] = None,
) -> User:
"""
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
@@ -94,8 +95,11 @@ class UserRepository:
user.notes = notes
if show_on_leaderboard is not None:
user.show_on_leaderboard = show_on_leaderboard
if tos_acceptance:
user.tos_acceptance_date = utcnow()
self.db.add(user)
return user
@managed_tx_method(CommitMode.COMMIT)
def mark_user_deleted(self, id: UUID) -> None:
@@ -143,8 +147,10 @@ class UserRepository:
display_name=display_name,
api_client_id=self.api_client.id,
auth_method=auth_method,
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
)
if auth_method == "system":
user.show_on_leaderboard = False # don't show system users, e.g. import user
user.tos_acceptance_date = utcnow()
self.db.add(user)
elif display_name and display_name != user.display_name:
# we found the user but the display name changed
@@ -156,6 +162,10 @@ class UserRepository:
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
if not client_user:
return None
if not (client_user.auth_method and client_user.id):
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
for i in range(num_retries):
try:
+274 -17
View File
@@ -5,19 +5,40 @@ from uuid import UUID
import sqlalchemy as sa
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models import (
Message,
MessageReaction,
MessageTreeState,
Task,
TextLabels,
TrollStats,
User,
UserStats,
UserStatsTimeFrame,
)
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
LabelInitialPromptPayload,
LabelPrompterReplyPayload,
RankingReactionPayload,
)
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_shared.schemas.protocol import (
EmojiCode,
LabelTaskMode,
LeaderboardStats,
TextLabel,
TrollboardStats,
TrollScore,
UserScore,
)
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.functions import coalesce
from sqlmodel import Session, delete, func, text
def _create_user_score(r, highlighted_user_id: UUID | None):
def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore:
if r["UserStats"]:
d = r["UserStats"].dict()
else:
@@ -37,6 +58,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None):
return UserScore(**d)
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
    """Build a ``TrollScore`` from one result row of a trollboard query.

    The row is indexed by column name. When its ``TrollStats`` entry is
    falsy, the score starts from a dict holding only a fresh
    ``modified_date``; otherwise it starts from ``TrollStats.dict()``.
    The user columns of the row are then copied on top.

    ``highlighted`` is only set when ``highlighted_user_id`` is truthy; it is
    ``True`` for the row that belongs to that user.
    """
    troll_stats = r["TrollStats"]
    fields = troll_stats.dict() if troll_stats else {"modified_date": utcnow()}

    # user columns selected next to the TrollStats entity
    user_columns = ("user_id", "username", "auth_method", "display_name", "last_activity_date")
    fields.update((column, r[column]) for column in user_columns)

    if highlighted_user_id:
        fields["highlighted"] = r["user_id"] == highlighted_user_id
    return TrollScore(**fields)
class UserStatsRepository:
def __init__(self, session: Session):
self.session = session
@@ -63,7 +102,7 @@ class UserStatsRepository:
UserStats,
)
.join(UserStats, User.id == UserStats.user_id)
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard, User.enabled)
.order_by(UserStats.rank)
.limit(limit)
)
@@ -82,7 +121,7 @@ class UserStatsRepository:
window_size: int = 5,
) -> LeaderboardStats | None:
# no window for users who don't show themselves
if not user.show_on_leaderboard:
if not user.show_on_leaderboard or not user.enabled:
return None
qry = self.session.query(UserStats).filter(UserStats.user_id == user.id, UserStats.time_frame == time_frame)
@@ -105,7 +144,7 @@ class UserStatsRepository:
UserStats,
)
.join(UserStats, User.id == UserStats.user_id)
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard)
.filter(UserStats.time_frame == time_frame.value, User.show_on_leaderboard, User.enabled)
.where(UserStats.rank >= min_rank, UserStats.rank <= max_rank)
.order_by(UserStats.rank)
)
@@ -119,7 +158,16 @@ class UserStatsRepository:
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
qry = (
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
self.session.query(
User.id.label("user_id"),
User.username,
User.auth_method,
User.display_name,
User.streak_days,
User.streak_last_day_date,
User.last_activity_date,
UserStats,
)
.outerjoin(UserStats, User.id == UserStats.user_id)
.filter(User.id == user_id)
)
@@ -133,6 +181,38 @@ class UserStatsRepository:
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
return stats_by_timeframe
def get_trollboard(
    self,
    time_frame: UserStatsTimeFrame,
    limit: int = 100,
    highlighted_user_id: Optional[UUID] = None,
) -> TrollboardStats:
    """
    Get trollboard stats for the specified time frame.

    Returns at most `limit` entries ordered by the persisted troll rank.
    `last_updated` is the newest `modified_date` among the returned entries,
    or the current time when the trollboard is empty.
    """
    columns = (
        User.id.label("user_id"),
        User.username,
        User.auth_method,
        User.display_name,
        User.last_activity_date,
        TrollStats,
    )
    qry = self.session.query(*columns)
    qry = qry.join(TrollStats, User.id == TrollStats.user_id)
    qry = qry.filter(TrollStats.time_frame == time_frame.value)
    qry = qry.order_by(TrollStats.rank).limit(limit)

    trollboard = [_create_troll_score(row, highlighted_user_id) for row in self.session.exec(qry)]
    last_update = max(entry.modified_date for entry in trollboard) if trollboard else utcnow()
    return TrollboardStats(time_frame=time_frame.value, trollboard=trollboard, last_updated=last_update)
def query_total_prompts_per_user(
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
):
@@ -248,9 +328,9 @@ class UserStatsRepository:
for r in qry:
uid, mode, count = r
s = get_stats(uid)
if mode == "simple":
if mode == LabelTaskMode.simple:
s.labels_simple = count
elif mode == "full":
elif mode == LabelTaskMode.full:
s.labels_full = count
qry = self.query_labels_by_mode_per_user(
@@ -259,9 +339,20 @@ class UserStatsRepository:
for r in qry:
uid, mode, count = r
s = get_stats(uid)
if mode == "simple":
if mode == LabelTaskMode.simple:
s.labels_simple += count
elif mode == "full":
elif mode == LabelTaskMode.full:
s.labels_full += count
qry = self.query_labels_by_mode_per_user(
payload_type=LabelInitialPromptPayload.__name__, reference_time=base_date
)
for r in qry:
uid, mode, count = r
s = get_stats(uid)
if mode == LabelTaskMode.simple:
s.labels_simple += count
elif mode == LabelTaskMode.full:
s.labels_full += count
qry = self.query_rankings_per_user(reference_time=base_date)
@@ -292,10 +383,145 @@ class UserStatsRepository:
self.session.add_all(stats_by_user.values())
self.session.flush()
self.update_ranks(time_frame=time_frame)
self.update_leader_ranks(time_frame=time_frame)
def query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None):
    """
    Build a query that sums the thumbs-up, thumbs-down and red-flag emoji
    counters of each user's messages (result columns: user_id, up, down, flag).

    Only non-deleted messages with a non-NULL `emojis` value are counted.
    When `reference_time` is given, only messages created at or after it
    are included.
    """

    def emoji_total(code, column_name):
        # a missing emoji key counts as 0
        return func.sum(coalesce(Message.emojis[code].cast(sa.Integer), 0)).label(column_name)

    conditions = [Message.deleted == sa.false(), Message.emojis.is_not(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    return (
        self.session.query(
            Message.user_id,
            emoji_total(EmojiCode.thumbs_up, "up"),
            emoji_total(EmojiCode.thumbs_down, "down"),
            emoji_total(EmojiCode.red_flag, "flag"),
        )
        .filter(*conditions)
        .group_by(Message.user_id)
    )
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
    """
    Build a query that counts, per user, the message trees in state
    `ABORTED_LOW_GRADE` (result columns: user_id, spam_prompts).

    Each tree state row is joined to the message whose id equals the tree id.
    When `reference_time` is given, only messages created at or after it
    are counted.
    """
    conditions = [MessageTreeState.state == TreeState.ABORTED_LOW_GRADE]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    spam_count = func.count().label("spam_prompts")
    return (
        self.session.query(Message.user_id, spam_count)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.id)
        .filter(*conditions)
        .group_by(Message.user_id)
    )
def query_labels_per_user(self, reference_time: Optional[datetime] = None):
    """
    Build a query that aggregates the text labels attached to each user's
    messages.

    Flag-like labels are summed as integers (a missing label counts as 0),
    the remaining labels are averaged as floats. When `reference_time` is
    given, only messages created at or after it are included.
    """

    def flag_total(label, column_name):
        return func.sum(coalesce(TextLabels.labels[label].cast(sa.Integer), 0)).label(column_name)

    def mean_value(label, column_name):
        return func.avg(TextLabels.labels[label].cast(sa.Float)).label(column_name)

    columns = [
        Message.user_id,
        flag_total(TextLabel.spam, "spam"),
        # the result column is deliberately spelled "lang_mismach": the troll
        # stats update reads the row by exactly that name
        flag_total(TextLabel.lang_mismatch, "lang_mismach"),
        flag_total(TextLabel.not_appropriate, "not_appropriate"),
        flag_total(TextLabel.pii, "pii"),
        flag_total(TextLabel.hate_speech, "hate_speech"),
        flag_total(TextLabel.sexual_content, "sexual_content"),
        flag_total(TextLabel.political_content, "political_content"),
        mean_value(TextLabel.quality, "quality"),
        mean_value(TextLabel.humor, "humor"),
        mean_value(TextLabel.toxicity, "toxicity"),
        mean_value(TextLabel.violence, "violence"),
        mean_value(TextLabel.helpfulness, "helpfulness"),
    ]

    # NOTE(review): labels are only aggregated for messages whose `emojis`
    # value is not NULL, the same restriction the emoji count query uses;
    # confirm this is intended for label statistics.
    conditions = [Message.deleted == sa.false(), Message.emojis.is_not(None)]
    if reference_time:
        conditions.append(Message.created_date >= reference_time)

    return (
        self.session.query(*columns)
        .select_from(TextLabels)
        .join(Message, TextLabels.message_id == Message.id)
        .filter(*conditions)
        .group_by(Message.user_id)
    )
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
    """
    Recompute and persist the `TrollStats` rows of one time frame.

    Gathers per-user emoji counts, spam prompt counts and label aggregates
    (each restricted to messages created at or after `base_date` when it is
    given), replaces all existing `TrollStats` rows of the time frame with
    the freshly computed ones and finally updates the troll ranks.
    The session is flushed but not committed here.
    """
    # gather user data
    time_frame_key = time_frame.value
    stats_by_user: dict[UUID, TrollStats] = dict()
    now = utcnow()

    def get_stats(user_id: UUID) -> TrollStats:
        # lazily create one stats object per user
        us = stats_by_user.get(user_id)
        if not us:
            us = TrollStats(user_id=user_id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
            stats_by_user[user_id] = us
        return us

    # emoji counts of user's messages
    qry = self.query_message_emoji_counts_per_user(reference_time=base_date)
    for r in qry:
        s = get_stats(r["user_id"])
        s.upvotes = r["up"]
        s.downvotes = r["down"]
        s.red_flags = r["flag"]

    # num spam prompts
    qry = self.query_spam_prompts_per_user(reference_time=base_date)
    for uid, count in qry:
        get_stats(uid).spam_prompts = count

    label_field_names = (
        "quality",
        "humor",
        "toxicity",
        "violence",
        "helpfulness",
        "spam",
        "lang_mismach",
        "not_appropriate",
        "pii",
        "hate_speech",
        "sexual_content",
        "political_content",
    )

    # label counts / mean values
    qry = self.query_labels_per_user(reference_time=base_date)
    for r in qry:
        s = get_stats(r["user_id"])
        for fn in label_field_names:
            setattr(s, fn, r[fn])

    # delete all existing stats for time frame
    d = delete(TrollStats).where(TrollStats.time_frame == time_frame_key)
    self.session.execute(d)

    if None in stats_by_user:
        logger.warning("Some messages in DB have NULL values in user_id column.")
        del stats_by_user[None]

    # compute troll score
    for v in stats_by_user.values():
        v.troll_score = v.compute_troll_score()

    # insert user objects
    self.session.add_all(stats_by_user.values())
    self.session.flush()

    self.update_troll_ranks(time_frame=time_frame)
@log_timing(log_kwargs=True)
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None):
"""
Update user_stats ranks. The persisted rank values allow to
quickly look up the rank of a single user and to query nearby users.
@@ -321,7 +547,7 @@ FROM
ORDER BY leader_score DESC, user_id
) AS "rank", user_id, time_frame
FROM user_stats us2
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard
INNER JOIN "user" u ON us2.user_id = u.id AND u.show_on_leaderboard AND u.enabled
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
us.user_id = r.user_id
@@ -329,10 +555,41 @@ WHERE
r = self.session.execute(
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
)
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
self._update_stats_internal(time_frame, reference_time)
@log_timing(log_kwargs=True)
def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None):
    """
    Update the persisted "rank" column of the troll_stats table.

    Ranks are row numbers computed separately per time frame, ordered by
    troll_score descending with user_id as tie breaker. When `time_frame`
    is None the ranks of all time frames are updated, otherwise only the
    rows of the given time frame.
    """
    sql_update_troll_rank = """
-- update rank
UPDATE troll_stats ts
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY troll_score DESC, user_id
) AS "rank", user_id, time_frame
FROM troll_stats ts2
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
ts.user_id = r.user_id
AND ts.time_frame = r.time_frame;"""
    # a NULL :time_frame parameter disables the time frame restriction in the subquery
    r = self.session.execute(
        text(sql_update_troll_rank), {"time_frame": time_frame.value if time_frame is not None else None}
    )
    logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(
    self,
    time_frame: UserStatsTimeFrame,
    reference_time: Optional[datetime] = None,
    leader_stats: bool = True,
    troll_stats: bool = True,
):
    """
    Recompute the statistics of one time frame and commit the session.

    `leader_stats` and `troll_stats` select which statistics are recomputed;
    leader stats run first. The commit happens even when both are disabled.
    """
    recompute_steps = []
    if leader_stats:
        recompute_steps.append(self._update_stats_internal)
    if troll_stats:
        recompute_steps.append(self._update_troll_stats_internal)

    for recompute in recompute_steps:
        recompute(time_frame, reference_time)
    self.session.commit()
@log_timing(log_kwargs=True, level="INFO")
+1 -1
View File
@@ -66,7 +66,7 @@ def get_winner(pairs):
def get_ranking(pairs):
"""
Abuses concordance property to get a (not necessarily unqiue) ranking.
Abuses concordance property to get a (not necessarily unique) ranking.
The lack of uniqueness is due to the potential existence of multiple
equally ranked winners. We have to pick one, which is where
the non-uniqueness comes from
+54 -10
View File
@@ -1,12 +1,16 @@
from __future__ import annotations
import contextlib
import gzip
import json
import sys
from collections import defaultdict
from typing import Optional, TextIO
from typing import Iterable, Optional, TextIO
from uuid import UUID
from fastapi.encoders import jsonable_encoder
from oasst_backend.models import Message
from oasst_backend.models.message_tree_state import State as TreeState
from pydantic import BaseModel
@@ -17,6 +21,7 @@ class ExportMessageNode(BaseModel):
role: str
lang: str | None
review_count: int | None
review_result: bool | None
rank: int | None
synthetic: bool | None
model_name: str | None
@@ -32,6 +37,7 @@ class ExportMessageNode(BaseModel):
role=message.role,
lang=message.lang,
review_count=message.review_count,
review_result=message.review_result if message.review_result or message.review_count > 2 else None,
synthetic=message.synthetic,
model_name=message.model_name,
emojis=message.emojis,
@@ -41,10 +47,13 @@ class ExportMessageNode(BaseModel):
class ExportMessageTree(BaseModel):
message_tree_id: str
tree_state: Optional[str]
prompt: Optional[ExportMessageNode]
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
def build_export_tree(
message_tree_id: UUID, message_tree_state: TreeState, messages: list[Message]
) -> ExportMessageTree:
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]
messages_by_parent = defaultdict(list)
@@ -59,19 +68,54 @@ def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMe
return node
prompt = assign_replies(messages_by_parent[None][0])
return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt)
return ExportMessageTree(message_tree_id=str(message_tree_id), tree_state=message_tree_state, prompt=prompt)
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
out_buff: TextIO
if use_compression:
out_buff = gzip.open(file, "wt", encoding="UTF-8")
# see https://stackoverflow.com/questions/17602878/how-to-handle-both-with-open-and-sys-stdout-nicely
@contextlib.contextmanager
def smart_open(filename: str = None) -> TextIO:
if filename and filename != "-":
fh = open(filename, "wt", encoding="UTF-8")
else:
out_buff = open(file, "wt", encoding="UTF-8")
fh = sys.stdout
try:
yield fh
finally:
if fh is not sys.stdout:
fh.close()
def write_trees_to_file(filename: str | None, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
out_buff: TextIO
if use_compression:
if not filename:
raise RuntimeError("File name must be specified when using compression.")
out_buff = gzip.open(filename, "wt", encoding="UTF-8")
else:
out_buff = smart_open(filename)
with out_buff as f:
for tree in trees:
file_data = jsonable_encoder(tree)
file_data = jsonable_encoder(tree, exclude_none=True)
json.dump(file_data, f)
f.write("\n")
def write_messages_to_file(filename: str | None, messages: Iterable[Message], use_compression: bool = True) -> None:
    """Write messages as JSON lines, one exported message node per line.

    With `use_compression` the output is a gzip-compressed text file and a
    file name is mandatory. Without compression the target is opened through
    `smart_open`. Fields that are None are omitted from the JSON objects.

    Raises:
        RuntimeError: if compression is requested without a file name.
    """
    if not use_compression:
        sink = smart_open(filename)
    elif filename:
        sink = gzip.open(filename, "wt", encoding="UTF-8")
    else:
        raise RuntimeError("File name must be specified when using compression.")

    with sink as out:
        for message in messages:
            node = ExportMessageNode.prep_message_export(message)
            json.dump(jsonable_encoder(node, exclude_none=True), out)
            out.write("\n")
+1
View File
@@ -1,3 +1,4 @@
aiohttp==3.8.3
alembic==1.8.1
cryptography==39.0.0
fastapi==0.88.0
+156
View File
@@ -0,0 +1,156 @@
# Collection of SQL Snippets
Here are some SQL queries to inspect the current OA Postgres DB.
# Basic Stats
```sql
-- tables row counts
(select 'user' as "table", count(*) from "user") union
(select 'task', count(*) from task) union
(select 'message_tree_state', count(*) from message_tree_state) union
(select 'message_reaction', count(*) from message_reaction) union
(select 'text_labels', count(*) from text_labels) union
(select 'message', count(*) from message) union
(select 'journal', count(*) from journal);
```
# Messages
```sql
-- only human by role
select role, count(*)
from message
where not deleted and review_result and not synthetic
group by role;
```
```sql
-- language distribution of messages (incl. synthetic)
select lang, count(*), synthetic from message where not deleted and review_result
group by lang, synthetic;
```
```sql
-- only human generated messages by lang
select lang, count(*)
from message
where not deleted and review_result and not synthetic
group by lang;
```
## Message Trees
```sql
-- total count of message trees
select count(*) from message_tree_state;
```
```sql
-- message tree counts by state
select state, count(*) from message_tree_state group by state;
```
```sql
-- count of waiting initial prompts by language
select m.lang, count(*)
from message_tree_state mts
join message m on mts.message_tree_id = m.id
where mts.state = 'prompt_lottery_waiting'
group by m.lang;
```
```sql
-- message trees by lang in ready_for_export or growing state
select m.lang, mts.state, count(*)
from message_tree_state mts
join message m on mts.message_tree_id = m.id
where mts.state in ('ready_for_export', 'growing')
group by mts.state, m.lang
order by lang, state;
```
```sql
-- select message tree counts
select mts.message_tree_id, count(m.id), max(m.depth), count(m.id) filter (where m.role='prompter') as prompter, count(m.id) filter (where m.role='assistant') as assistant
from message_tree_state mts
join message m on mts.message_tree_id = m.message_tree_id
where mts.state='growing'
and not m.deleted
and m.review_result=true
and m.lang='en'
and mts.active
group by mts.message_tree_id
order by count(m.id) desc;
```
```sql
-- show top 100 largest trees
select mts.message_tree_id, mts.goal_tree_size, mts.state, count(m.id) as message_count
from message_tree_state mts
join message m on mts.message_tree_id = m.message_tree_id
where not m.deleted and m.review_result=true
group by mts.message_tree_id, mts.state
order by count(m.id) desc
limit 100;
```
```sql
-- active trees, current & goal_size
select mts.message_tree_id, mts.state, mts.goal_tree_size, count(m.id) AS tree_size, max(m.depth) AS max_depth
from message_tree_state mts
join message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active
and not m.deleted
and m.review_result
group by mts.message_tree_id, mts.goal_tree_size;
```
## Users
```sql
-- count users that accepted tos
select count(*) from "user" where tos_acceptance_date is not null;
```
```sql
-- last 25 active users
select u.id, u.username, u.auth_method, u.display_name, u.last_activity_date, age(current_timestamp, last_activity_date) from "user" u WHERE u.last_activity_date is not null order by u.last_activity_date desc limit 25;
select id, display_name, username, auth_method, last_activity_date from "user" where age(last_activity_date) < interval '1 minutes' order by last_activity_date desc limit 25;
```
```sql
-- count active users in last 5 mins
select count(*) from "user" u where age(current_timestamp, last_activity_date) < interval '5 mins';
```
```sql
-- total count of non-deleted messages (human + synth)
select count(*) from message where deleted=false and review_result=true;
```
```sql
-- count max, mean message counts per tree for a given language
with t(message_tree_id, tree_size, state) as (select mts.message_tree_id, count(m.id), mts.state
from message_tree_state mts
join message m on mts.message_tree_id = m.message_tree_id
where
not m.deleted
and m.review_result=true
and m.lang = 'en'
group by mts.message_tree_id)
select state, count(t.*) as trees, sum(t.tree_size) as total_msgs, max(t.tree_size), avg(t.tree_size) from t group by t.state;
```
## Connections
```sql
-- from https://dba.stackexchange.com/questions/161760/number-of-active-connections-and-remaining-connections
select max_conn,used,res_for_super,max_conn-used-res_for_super res_for_normal
from
(select count(*) used from pg_stat_activity) t1,
(select setting::int res_for_super from pg_settings where name=$$superuser_reserved_connections$$) t2,
(select setting::int max_conn from pg_settings where name=$$max_connections$$) t3;
```
+1 -1
View File
@@ -29,7 +29,7 @@ This will create a variety of aws roles and services needed for deployment.
copilot deploy
```
This will depoy the services but it won't be 100% ready for usage. Before being
This will deploy the services but it won't be 100% ready for usage. Before being
ready, we have to inspect the AWS Secrets manager and extract out the database
credentials. Read those credentials then put them, and a few other secrets, in a
`secrets.yml` file like the following:
+1 -1
View File
@@ -4,4 +4,4 @@ OWNER_IDS=[<your user id>, <other user ids>]
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
OASST_API_URL="http://localhost:8080" # No trailing '/'
OASST_API_KEY=""
OASST_API_KEY="1234"
+388 -391
View File
@@ -9,27 +9,25 @@ import lightbulb.decorators
import miru
from aiosqlite import Connection
from bot.messages import (
assistant_reply_message,
assistant_reply_messages,
confirm_label_response_message,
confirm_ranking_response_message,
confirm_text_response_message,
initial_prompt_message,
invalid_user_input_embed,
label_assistant_reply_message,
label_initial_prompt_message,
label_prompter_reply_message,
initial_prompt_messages,
label_assistant_reply_messages,
label_prompter_reply_messages,
plain_embed,
prompter_reply_message,
prompter_reply_messages,
rank_assistant_reply_message,
rank_initial_prompts_message,
rank_prompter_reply_message,
rank_conversation_reply_messages,
rank_initial_prompts_messages,
rank_prompter_reply_messages,
task_complete_embed,
)
from bot.settings import Settings
from loguru import logger
from oasst_shared.api_client import OasstApiClient, TaskType
from oasst_shared.api_client import OasstApiClient
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("WorkPlugin")
@@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
settings = Settings()
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
class _TaskHandler(t.Generic[_Task_contra]):
    """Handle user interaction for a task.

    Subclasses supply the task-specific parts: `get_task_messages`,
    `check_user_input`, `confirm_user_input` and `notify`.
    """

    def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
        """Create a new `TaskHandler`.

        Args:
            ctx (lightbulb.Context): The context of the command that started the task.
            task (_Task_contra): The task to handle.
        """
        self.ctx = ctx
        self.task = task
        self.task_messages = self.get_task_messages(task)
        # every message sent to the user by `send`, in sending order
        self.sent_messages: list[hikari.Message] = []

    @staticmethod
    def get_task_messages(task: _Task_contra) -> list[str]:
        """Get the messages to send to the user for the task."""
        raise NotImplementedError

    async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
        """Send the task and wait for the user to accept/skip/cancel it."""
        # Send all but the last message because we need to attach buttons to the last one
        logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
        for task_msg in self.task_messages[:-1]:
            if len(task_msg) > 2000:
                logger.warning(f"Attempting to send a message >2000 characters in length. Task id: {self.task.id}")
                task_msg = task_msg[:1999]
            self.sent_messages.append(await self.ctx.author.send(task_msg))

        # Send the last message with buttons
        task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
        logger.debug(f"TH Message length {len(self.task_messages[-1])}")
        last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
        # Record the last message too: `handle` and `notify` read `sent_messages[0]`,
        # which would otherwise raise IndexError for tasks that consist of a single message.
        self.sent_messages.append(last_msg)
        await task_accept_view.start(last_msg)
        await task_accept_view.wait()

        return task_accept_view.choice

    async def handle(self) -> None:
        """Handle the user's response to the task.

        This method should be called after `send` has been called.

        Returns silently when the user does not answer within `MAX_TASK_TIME`.

        Raises:
            TypeError: if the backend answers the interaction with anything but `TaskDone`.
        """
        # Ack task to the backend
        oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
        await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")

        # Loop until the user's input is accepted
        while True:
            try:
                # Wait for user to send a message (commands starting with the prefix are ignored)
                event = await self.ctx.bot.wait_for(
                    hikari.DMMessageCreateEvent,
                    predicate=lambda e: (
                        e.author_id == self.ctx.author.id
                        and e.message.content is not None
                        and not e.message.content.startswith(settings.prefix)
                    ),
                    timeout=MAX_TASK_TIME,
                )

                # Validate the message
                if event.content is None or not self.check_user_input(event.content):
                    await self.ctx.author.send("Invalid input")
                    continue

                # Confirm user input
                if not (await self.confirm_user_input(event.content)):
                    continue

                # Message is valid and confirmed by user
                break

            except asyncio.TimeoutError:
                return

        next_task = await self.notify(event.content, event)

        if not isinstance(next_task, protocol_schema.TaskDone):
            raise TypeError(f"Unknown task type: {next_task!r}")

        return

    async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
        """Notify the backend that the user completed the task."""
        raise NotImplementedError

    async def confirm_user_input(self, content: str) -> bool:
        """Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
        raise NotImplementedError

    def check_user_input(self, content: str) -> bool:
        """Check the user's response to the task. Returns True if the response is valid."""
        raise NotImplementedError

    async def cancel(self, reason: str = "not specified") -> None:
        """Cancel the task."""
        oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
        await oasst_api.nack_task(self.task.id, reason)
_Ranking_contra = t.TypeVar(
"_Ranking_contra",
bound=protocol_schema.RankAssistantRepliesTask
| protocol_schema.RankInitialPromptsTask
| protocol_schema.RankPrompterRepliesTask
| protocol_schema.RankConversationRepliesTask,
contravariant=True,
)
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
    """This should not be used directly. Use its subclasses instead."""

    async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
        """Post the user's ranking to the backend and mirror the task to the guild's log channel.

        Args:
            content (str): Comma separated 1-based ranks as entered by the user;
                they are converted to 0-based indices for the backend.
            event (hikari.DMMessageCreateEvent): The event carrying the user's answer (unused here).

        Returns:
            protocol_schema.Task: The backend's answer to the posted interaction.
        """
        oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
        task = await oasst_api.post_interaction(
            protocol_schema.MessageRanking(
                user=protocol_schema.User(
                    id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
                ),
                ranking=[int(r) - 1 for r in content.split(",")],
                message_id=f"{self.sent_messages[0].id}",
            )
        )

        # look up the log channel configured for the guild, if any
        db: Connection = self.ctx.bot.d.db
        async with db.cursor() as cursor:
            row = await (
                await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
            ).fetchone()
            log_channel = row[0] if row else None

        if log_channel is not None:
            # all task messages except the last one are copied to the log channel
            for message in self.task_messages[:-1]:
                await self.ctx.bot.rest.create_message(log_channel, message)

            await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))

        return task
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
    """Handle a task that asks the user to rank assistant replies."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
        """Get the messages to send to the user for the task."""
        return rank_assistant_reply_message(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma separated ranking that uses every rank from 1 to the number of replies exactly once."""
        ranks = content.split(",")
        expected = len(self.task.reply_messages)
        if len(ranks) != expected:
            return False
        # str.isdecimal, unlike str.isdigit, only accepts characters that int() can
        # parse; isdigit lets e.g. "²" through and int() then raises ValueError.
        if not all(r.isdecimal() for r in ranks):
            return False
        # every rank must appear exactly once, otherwise the input is not a ranking
        return sorted(int(r) for r in ranks) == list(range(1, expected + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user and return True if they confirm it."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
    """Handle a task that asks the user to rank initial prompts."""

    def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
        super().__init__(ctx, task)

    @staticmethod
    def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
        """Get the messages to send to the user for the task."""
        return rank_initial_prompts_messages(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma separated ranking that uses every rank from 1 to the number of prompts exactly once."""
        ranks = content.split(",")
        expected = len(self.task.prompt_messages)
        if len(ranks) != expected:
            return False
        # str.isdecimal, unlike str.isdigit, only accepts characters that int() can
        # parse; isdigit lets e.g. "²" through and int() then raises ValueError.
        if not all(r.isdecimal() for r in ranks):
            return False
        # every rank must appear exactly once, otherwise the input is not a ranking
        return sorted(int(r) for r in ranks) == list(range(1, expected + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user and return True if they confirm it."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
    """Handle a task that asks the user to rank prompter replies."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
        """Get the messages to send to the user for the task."""
        return rank_prompter_reply_messages(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma separated ranking that uses every rank from 1 to the number of replies exactly once."""
        ranks = content.split(",")
        expected = len(self.task.reply_messages)
        if len(ranks) != expected:
            return False
        # str.isdecimal, unlike str.isdigit, only accepts characters that int() can
        # parse; isdigit lets e.g. "²" through and int() then raises ValueError.
        if not all(r.isdecimal() for r in ranks):
            return False
        # every rank must appear exactly once, otherwise the input is not a ranking
        return sorted(int(r) for r in ranks) == list(range(1, expected + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user and return True if they confirm it."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
    """Handle a task that asks the user to rank conversation replies."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
        """Get the messages to send to the user for the task."""
        return rank_conversation_reply_messages(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma separated ranking that uses every rank from 1 to the number of replies exactly once."""
        ranks = content.split(",")
        expected = len(self.task.reply_messages)
        if len(ranks) != expected:
            return False
        # str.isdecimal, unlike str.isdigit, only accepts characters that int() can
        # parse; isdigit lets e.g. "²" through and int() then raises ValueError.
        if not all(r.isdecimal() for r in ranks):
            return False
        # every rank must appear exactly once, otherwise the input is not a ranking
        return sorted(int(r) for r in ranks) == list(range(1, expected + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user and return True if they confirm it."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
    """Task handler for collecting an initial prompt as free text."""

    @staticmethod
    def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
        """Return the messages produced by `initial_prompt_messages` for `task`."""
        rendered = initial_prompt_messages(task)
        return rendered

    def check_user_input(self, content: str) -> bool:
        """Accept any non-empty text."""
        return bool(content)

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the text back with a yes/no view; return True if the choice is truthy."""
        yes_no = YesNoView()
        sent = await self.ctx.author.send(confirm_text_response_message(content), components=yes_no)
        await yes_no.start(sent)
        await yes_no.wait()
        return bool(yes_no.choice)
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
    """Task handler for collecting a prompter reply as free text."""

    @staticmethod
    def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
        """Return the messages produced by `prompter_reply_messages` for `task`."""
        rendered = prompter_reply_messages(task)
        return rendered

    def check_user_input(self, content: str) -> bool:
        """Accept any non-empty text."""
        return bool(content)

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the text back with a yes/no view; return True if the choice is truthy."""
        yes_no = YesNoView()
        sent = await self.ctx.author.send(confirm_text_response_message(content), components=yes_no)
        await yes_no.start(sent)
        await yes_no.wait()
        return bool(yes_no.choice)
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
    """Task handler for collecting an assistant reply as free text."""

    @staticmethod
    def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
        """Return the messages produced by `assistant_reply_messages` for `task`."""
        rendered = assistant_reply_messages(task)
        return rendered

    def check_user_input(self, content: str) -> bool:
        """Accept any non-empty text."""
        return bool(content)

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the text back with a yes/no view; return True if the choice is truthy."""
        yes_no = YesNoView()
        sent = await self.ctx.author.send(confirm_text_response_message(content), components=yes_no)
        await yes_no.start(sent)
        await yes_no.wait()
        return bool(yes_no.choice)
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)


class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
    """Shared input validation and confirmation for label tasks."""

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a valid comma separated label list.

        Every label the user typed must be one of the task's `valid_labels`,
        and every label in the task's `mandatory_labels` must be present.
        """
        user_labels = content.split(",")
        valid_labels = self.task.valid_labels
        # `mandatory_labels` may be None, meaning nothing is mandatory; that
        # must not cause every answer to be rejected.
        mandatory_labels = self.task.mandatory_labels or ()
        return all(label in valid_labels for label in user_labels) and all(
            mandatory in user_labels for mandatory in mandatory_labels
        )

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the labels back with a yes/no view; return True if the choice is truthy."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
    """Label-task handler specialised for assistant replies."""

    @staticmethod
    def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
        """Return the messages produced by `label_assistant_reply_messages` for `task`."""
        rendered = label_assistant_reply_messages(task)
        return rendered
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
    """Label-task handler specialised for prompter replies."""

    @staticmethod
    def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
        """Return the messages produced by `label_prompter_reply_messages` for `task`."""
        rendered = label_prompter_reply_messages(task)
        return rendered
summarize_story = "summarize_story"
rate_summary = "rate_summary"
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
required=False,
default=str(TaskRequestType.random),
type=str,
)
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def work(ctx: lightbulb.Context):
"""Create and handle a task."""
# Only send this message if started from a server
if ctx.guild_id is not None:
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
currently_working: dict[
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
] = ctx.bot.d.currently_working
async def work2(ctx: lightbulb.Context) -> None:
"""Complete a task."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
# Check if the user is already working on a task
if ctx.author.id in currently_working:
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
@@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context):
case False | None:
return
case True:
old_msg, task_id = currently_working[ctx.author.id]
if old_msg is not None:
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
map(lambda c: c, old_msg.components)
await old_msg.delete()
if task_id is not None:
await oasst_api.nack_task(task_id, reason="user cancelled")
task_id = currently_working[ctx.author.id]
await oasst_api.nack_task(task_id, reason="user cancelled")
await msg.delete()
if ctx.guild_id:
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
currently_working[ctx.author.id] = (None, None)
# Create a TaskRequestType from the stringified enum value
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
logger.debug(f"Starting task_type: {task_type!r}")
# Keep sending tasks until the user doesn't want more
try:
await _handle_task(ctx, task_type)
while True:
task = await oasst_api.fetch_random_task(
user=protocol_schema.User(
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
),
)
# Ranking tasks
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
task_handler = RankAssistantRepliesHandler(ctx, task)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
task_handler = RankInitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
task_handler = RankPrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
task_handler = RankConversationReplyHandler(ctx, task)
# Text input tasks
elif isinstance(task, protocol_schema.InitialPromptTask):
task_handler = InitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.PrompterReplyTask):
task_handler = PrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.AssistantReplyTask):
task_handler = AssistantReplyHandler(ctx, task)
# Label tasks
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
task_handler = LabelAssistantReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
task_handler = LabelPrompterReplyHandler(ctx, task)
else:
raise ValueError(f"Unknown task type: {type(task)}")
resp = await task_handler.send()
match resp:
case "accept":
currently_working[ctx.author.id] = task.id
await task_handler.handle()
case "next":
await task_handler.cancel("user skipped task")
case "cancel":
await task_handler.cancel("user canceled work")
break
case None:
await task_handler.cancel("select timed out")
break
finally:
del currently_working[ctx.author.id]
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
    """Handle creating and collecting user input for a task.

    Continually present tasks to the user until they select one, cancel, or time out.
    If they select one, present the task steps until a `task_done` task is received.
    Finally, ask the user if they want to perform another task (of the same type).

    Returns early (without asking for another task) when `_select_task` yields
    no task or when waiting for the user's DM times out.
    """
    oasst_api: OasstApiClient = ctx.bot.d.oasst_api

    # Continue to complete tasks until the user doesn't want to do another
    done = False
    while not done:
        # Loop until the user accepts a task
        task, msg_id = await _select_task(ctx, task_type)

        if task is None:
            # User cancelled
            return

        # Task action loop
        completed = False
        while not completed:
            await ctx.author.send(embed=plain_embed("Please type your response below:"))
            try:
                # Wait for the next DM from this user that is not a prefix command.
                event = await ctx.bot.wait_for(
                    hikari.DMMessageCreateEvent,
                    timeout=MAX_TASK_TIME,
                    predicate=lambda e: e.author.id == ctx.author.id
                    and not (e.message.content or "").startswith(settings.prefix),
                )
            except asyncio.TimeoutError:
                await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
                await oasst_api.nack_task(task.id, reason="timed out")
                logger.info(f"Task {task.id} timed out")
                return

            # Invalid response
            valid, err_msg = _validate_user_input(event.content, task)
            if not valid or event.content is None:
                await ctx.author.send(embed=invalid_user_input_embed(err_msg))
                continue

            logger.debug(f"Successful user input received: {event.content}")

            # Confirm user input
            if isinstance(task, protocol_schema.RankConversationRepliesTask):
                content = confirm_ranking_response_message(event.content, task.replies)
            elif isinstance(task, protocol_schema.RankInitialPromptsTask):
                content = confirm_ranking_response_message(event.content, task.prompts)
            elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
                content = confirm_label_response_message(event.content)
            elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
                content = confirm_text_response_message(event.content)
            else:
                logger.critical(f"Unknown task type: {task.type}")
                raise ValueError(f"Unknown task type: {task.type}")

            confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
            msg = await ctx.author.send(content, components=confirm_resp_view)
            await confirm_resp_view.start(msg)
            await confirm_resp_view.wait()

            match confirm_resp_view.choice:
                case False | None:
                    # Not confirmed (or no choice recorded): ask for the response again.
                    continue
                case True:
                    await msg.delete()  # buttons are already gone

            # Send the response to the backend
            if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
                reply = protocol_schema.MessageRanking(
                    message_id=str(msg_id),
                    # The user types 1-based ranks; they are sent as 0-based indices.
                    ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
                    user=protocol_schema.User(
                        auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
                    ),
                )
            elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
                labels = event.content.replace(" ", "").split(",")
                # Every valid label gets a value: 1 if the user listed it, else 0.
                labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
                reply = protocol_schema.TextLabels(
                    message_id=task.message_id,
                    labels=labels_dict,
                    user=protocol_schema.User(
                        auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
                    ),
                )
            elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
                reply = protocol_schema.TextReplyToMessage(
                    message_id=str(msg_id),
                    user_message_id=str(event.message_id),
                    user=protocol_schema.User(
                        auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
                    ),
                    text=event.content,
                )
            else:
                logger.critical(f"Unexpected task type received: {task.type}")
                raise ValueError(f"Unexpected task type received: {task.type}")

            logger.debug(f"Sending reply to backend: {reply!r}")

            # Get next task
            new_task = await oasst_api.post_interaction(reply)
            logger.info(f"New task {new_task}")
            if new_task.type == TaskType.done:
                await ctx.author.send(embed=plain_embed("Task completed"))
                completed = True
                continue
            else:
                # NOTE(review): `completed` stays False here, so the loop prompts the
                # user again for the same `task` — confirm this is intended.
                logger.critical(f"Unexpected task type received: {new_task.type}")

        # Send a message in all the log channels that the task is complete
        conn: Connection = ctx.bot.d.db
        async with conn.cursor() as cursor:
            await cursor.execute("SELECT log_channel_id FROM guild_settings")
            log_channel_ids = await cursor.fetchall()

        # Prefer the cache; fall back to a REST fetch when the channel is not cached.
        channels = [
            ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
            for id in log_channel_ids
        ]

        done_embed = task_complete_embed(task, ctx.author.mention)
        # This will definitely get the bot rate limited, but that's a future problem
        # NOTE(review): the gather() result is not awaited, so failures of the
        # individual sends are not observed in this function.
        asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))

        # ask the user if they want to do another task
        another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
        msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
        await another_task_view.start(msg)
        await another_task_view.wait()

        match another_task_view.choice:
            case False | None:
                done = True
                await msg.edit(embed=plain_embed("Exiting, goodbye!"))
            case True:
                pass
async def _select_task(
    ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
    """Offer tasks to the user one at a time until one is accepted.

    Each round fetches a task of `task_type`, shows it through `_send_task`
    (reusing the previous message once one exists) and reacts to the choice:
    "accept" ACKs and returns the task, "next" NACKs and offers another task,
    "cancel" or no choice NACKs, tells the user, and returns no task.

    Returns:
        `(task, message_id)` on accept, `(None, message_id)` otherwise.
    """
    api: OasstApiClient = ctx.bot.d.oasst_api

    logger.debug(f"Starting task selection for {task_type}")

    # Starts undefined so the first `_send_task` call creates the message.
    shown: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
    while True:
        logger.debug(f"Requesting task of type {task_type}")
        task = await api.fetch_task(task_type, user)

        choice, shown = await _send_task(ctx, task, shown)
        shown_id = str(shown.id)

        logger.debug(f"User choice: {choice}")
        if choice == "accept":
            logger.info(f"Task {task.id} accepted, sending ACK")
            await api.ack_task(task.id, shown_id)
            return task, shown_id
        elif choice == "next":
            logger.info(f"Task {task.id} rejected, sending NACK")
            await api.nack_task(task.id, "rejected")
        elif choice == "cancel":
            logger.info(f"Task {task.id} canceled, sending NACK")
            await api.nack_task(task.id, "canceled")
            await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
            return None, shown_id
        elif choice is None:
            logger.info(f"Task {task.id} timed out, sending NACK")
            await api.nack_task(task.id, "timed out")
            await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
            return None, shown_id
async def _send_task(
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
"""Send a task to the user.
Returns the user's choice and the message ID of the task message.
"""
# The clean way to do this would be to attach a `to_embed` method to the task classes
# but the tasks aren't discord specific so that doesn't really make sense.
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.debug("sending initial prompt task")
content = initial_prompt_message(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.debug("sending rank initial prompt task")
content = rank_initial_prompts_message(task)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
content = rank_prompter_reply_message(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
content = rank_assistant_reply_message(task)
elif task.type == TaskRequestType.label_initial_prompt:
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
logger.debug("sending label initial prompt task")
content = label_initial_prompt_message(task)
elif task.type == TaskRequestType.label_prompter_reply:
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
logger.debug("sending label prompter reply task")
content = label_prompter_reply_message(task)
elif task.type == TaskRequestType.label_assistant_reply:
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
logger.debug("sending label assistant reply task")
content = label_assistant_reply_message(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
content = prompter_reply_message(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.debug("sending assistant reply task")
content = assistant_reply_message(task)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"unknown task type {task.type}")
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
if not msg:
msg = await ctx.author.send(
content,
embed=embed,
components=view,
)
else:
await msg.edit(
content,
embed=embed,
components=view,
)
assert msg is not None
# Set the choice id as the current msg id
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
await view.start(msg)
await view.wait()
return view.choice, msg
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
"""Returns whether the user's input is valid for the task type and an error message."""
if content is None:
return False, "No input provided"
# User message input
if (
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.prompter_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0, "Message must be at least one character long."
# Ranking tasks
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
"Message must contain numbers for all replies.",
)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
"Message must contain numbers for all prompts.",
)
# Labels tasks
elif task.type in (
TaskRequestType.label_initial_prompt,
TaskRequestType.label_prompter_reply,
TaskRequestType.label_assistant_reply,
):
assert isinstance(
task,
protocol_schema.LabelInitialPromptTask
| protocol_schema.LabelPrompterReplyTask
| protocol_schema.LabelAssistantReplyTask,
)
labels = content.replace(" ", "").split(",")
valid_labels = set(task.valid_labels)
return (
set(labels).issubset(valid_labels),
"Message must only contain labels from predefined set of labels.",
)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"Unknown task type {task.type}")
raise ValueError(f"Unknown task type {task.type}")
class TaskAcceptView(miru.View):
"""View with three buttons: accept, next, and cancel.
+129 -69
View File
@@ -1,4 +1,11 @@
"""All user-facing messages and embeds."""
"""All user-facing messages and embeds.
When sending a conversation
- The function will return a list of strings
- use asyncio.gather to send all messages
-
"""
from datetime import datetime
@@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_"
def _label_prompt(text: str) -> str:
return f":question: _{text}"
def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
return f""":question: _{text}_
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
Valid labels: {", ".join(valid_labels)}
"""
def _response_prompt(text: str) -> str:
@@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str:
"""
def _make_ordered_list(items: list[str]) -> list[str]:
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)]
def _ordered_list(items: list[str]) -> str:
def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
return "\n\n".join(_make_ordered_list(items))
def _conversation(conv: protocol_schema.Conversation) -> str:
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
def _hint(hint: str | None) -> str:
return f"{NL}Hint: {hint}" if hint else ""
def _conversation(conv: protocol_schema.Conversation) -> list[str]:
# return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
messages = map(
lambda m: f"""\
:robot: __Assistant__:
{m.text}
"""
if m.is_assistant
else f"""\
:person_red_hair: __User__:
{m.text}
""",
conv.messages,
)
return list(messages)
def _li(text: str) -> str:
@@ -82,59 +101,80 @@ def _li(text: str) -> str:
###
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
return f"""\
return [
f"""\
{_h1("INITIAL PROMPT")}
:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:
{_writing_prompt("Please provide an initial prompt to the assistant.")}
{_hint(task.hint)}
:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""}
"""
]
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
return f"""\
return [
f"""\
{_h1("RANK INITIAL PROMPTS")}
:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:
{_ordered_list(task.prompts)}
{_ordered_list(task.prompt_messages)}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_
"""
]
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
return f"""\
return [
"""\
{_h1("RANK PROMPTER REPLIES")}
:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":person_red_hair: __User__:
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
]
{_conversation(task.conversation)}
{_user(None)}
{_ordered_list(task.replies)}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
return f"""\
return [
"""\
{_h1("RANK ASSISTANT REPLIES")}
:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":robot: __Assistant__:,
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
]
{_conversation(task.conversation)}
{_assistant(None)}
{_ordered_list(task.replies)}
def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_conversation_replies` task."""
return [
"""\
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":person_red_hair: __User__:
{_ordered_list(task.reply_messages)}
""",
]
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
@@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -
{task.prompt}
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
"""
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
return f"""\
return [
f"""\
{_h1("LABEL PROMPTER REPLY")}
{_conversation(task.conversation)}
{_user(None)}
""",
*_conversation(task.conversation),
f"""{_user(None)}
{task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
"""
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""",
]
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
return f"""\
return [
f"""\
{_h1("LABEL ASSISTANT REPLY")}
{_conversation(task.conversation)}
""",
*_conversation(task.conversation),
f"""
{_assistant(None)}
{task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
"""
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""",
]
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
return f"""\
return [
"""\
:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:
{_h1("PROMPTER REPLY")}
""",
*_conversation(task.conversation),
f"""{f"{NL}Hint: {task.hint}" if task.hint else ""}
:speech_balloon: _Please provide a reply to the assistant._
""",
]
{_conversation(task.conversation)}
{_hint(task.hint)}
{_response_prompt("Please provide a reply to the assistant.")}
"""
# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
# """Creates the message that gets sent to users when they request a `prompter_reply` task."""
# return [
# message_templates.render("title.msg", "PROMPTER REPLY"),
# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
# message_templates.render("prompter_reply_task.msg", task.hint),
# ]
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
return f"""\
{_h1("ASSISTANT REPLY")}
return [
"""\
:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
"""\
{_conversation(task.conversation)}
{_response_prompt("Please provide an assistant reply to the prompter.")}
"""
:speech_balloon: _Please provide a reply to the user as the assistant._
""",
]
def confirm_text_response_message(content: str) -> str:
@@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str:
"""
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
original_list = _make_ordered_list(items)
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
+62
View File
@@ -125,9 +125,71 @@ services:
- EMAIL_FROM=info@example.com
- NEXTAUTH_URL=http://localhost:3000
- DEBUG_LOGIN=true
- NEXT_PUBLIC_CLOUDFARE_CAPTCHA_SITE_KEY=1x00000000000000000000AA
- CLOUDFLARE_CAPTCHA_SECRET_KEY=1x0000000000000000000000000000000AA
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA=true
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN=true
depends_on:
webdb:
condition: service_healthy
ports:
- "3000:3000"
command: bash wait-for-postgres.sh node server.js
inference-server:
build:
dockerfile: docker/inference/Dockerfile.server
context: .
target: dev
image: oasst-inference-server:dev
environment:
- "PORT=8000"
- "REDIS_HOST=redis"
volumes:
- "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/server:/opt/inference/server"
restart: unless-stopped
depends_on:
redis:
condition: service_healthy
profiles: ["inference"]
inference-worker:
build:
dockerfile: docker/inference/Dockerfile.worker
context: .
target: dev
image: oasst-inference-worker:dev
environment:
- "BACKEND_URL=ws://inference-server:8000"
- "INFERENCE_SERVER_URL=http://inference-text-generation-server"
volumes:
- "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/worker:/opt/inference/worker"
depends_on:
- inference-server
deploy:
replicas: 1
profiles: ["inference"]
inference-text-client:
build:
dockerfile: docker/inference/Dockerfile.text-client
context: .
image: oasst-inference-text-client
environment:
- "BACKEND_URL=http://inference-server:8000"
tty: true
stdin_open: true
volumes:
- "./inference/worker:/opt/inference/worker"
restart: unless-stopped
depends_on:
- inference-server
profiles: ["inference"]
inference-text-generation-server:
image: ghcr.io/huggingface/text-generation-inference
environment:
- "MODEL_ID=distilgpt2"
profiles: ["inference"]
+90
View File
@@ -0,0 +1,90 @@
# syntax=docker/dockerfile:1

# Global build arguments; each stage re-declares the ones it uses.
ARG MODULE="inference"
ARG SERVICE="server"
ARG APP_USER="${MODULE}-${SERVICE}"
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"


# Build stage: install the service's requirements into /build/lib.
# It uses the same image as the runtime stage so that packages installed with
# `pip --target` match the interpreter version and libc they will run on.
FROM python:3.10-alpine3.17 as build
ARG APP_RELATIVE_PATH

WORKDIR /build

COPY ./${APP_RELATIVE_PATH}/requirements.txt .

RUN --mount=type=cache,target=/var/cache/pip \
    pip install \
    --cache-dir=/var/cache/pip \
    --target=lib \
    -r requirements.txt


# Runtime base: unprivileged user, library paths, and the application entry point.
FROM python:3.10-alpine3.17 as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
ARG MODULE
ARG SERVICE

ENV APP_BASE="/opt/${MODULE}"
ENV APP_ROOT="${APP_BASE}/${SERVICE}"
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
ENV SHARED_LIBS_BASE="${APP_BASE}/lib"

# Make the copied packages (and any scripts they ship) discoverable.
ENV PATH="${PATH}:${APP_LIBS}/bin"
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"

ENV PORT="8000"


RUN adduser \
    --disabled-password \
    --no-create-home \
    "${APP_USER}"

USER ${APP_USER}

WORKDIR ${APP_ROOT}


COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .


# Development image: oasst-shared is installed in editable mode and uvicorn
# runs with --reload.
FROM base-env as dev
ARG APP_USER

COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared ${SHARED_LIBS_BASE}/oasst-shared

# The editable install is performed as root; the unprivileged user is restored afterwards.
USER root
RUN --mount=type=cache,target=/var/cache/pip,from=build \
    pip install \
    --cache-dir=/var/cache/pip \
    -e "${SHARED_LIBS_BASE}/oasst-shared"
USER ${APP_USER}

VOLUME [ "${APP_BASE}/lib/oasst-shared" ]

CMD uvicorn main:app --reload --host 0.0.0.0 --port "${PORT}"


# Production image: oasst-shared is installed next to the other libraries.
FROM base-env as prod
ARG APP_USER

COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared /tmp/lib/oasst-shared

RUN --mount=type=cache,target=/var/cache/pip,from=dev \
    pip install \
    --cache-dir=/var/cache/pip \
    --target="${APP_LIBS}" \
    /tmp/lib/oasst-shared

CMD uvicorn main:app --host 0.0.0.0 --port "${PORT}"
+50
View File
@@ -0,0 +1,50 @@
# syntax=docker/dockerfile:1
# Global build arguments. They are declared before the first FROM, so each
# stage below has to re-declare (bare `ARG NAME`) the ones it wants to read.
ARG APP_USER="text-client"
ARG APP_RELATIVE_PATH="inference/text-client"
# --- build stage: install the client's requirements into /build/lib ---
# Same image as `base-env`, whose APP_LIBS is filled from this stage's output.
FROM python:3.10-alpine3.17 as build
ARG APP_RELATIVE_PATH
WORKDIR /build
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
# The BuildKit cache mount keeps pip's download cache across builds.
RUN --mount=type=cache,target=/var/cache/pip \
pip install \
--cache-dir=/var/cache/pip \
--target=lib \
-r requirements.txt
# --- base-env stage: runtime image with an unprivileged user ---
FROM python:3.10-alpine3.17 as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
# The client code lives in APP_ROOT and its third-party packages in APP_LIBS.
ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}"
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
# APP_LIBS is populated from a `pip install --target` directory, so it is added
# to PYTHONPATH for imports and its bin/ sub-directory to PATH for commands.
ENV PATH="${PATH}:${APP_LIBS}/bin"
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
# Create the unprivileged user (no password, no home directory).
RUN adduser \
--disabled-password \
--no-create-home \
"${APP_USER}"
USER ${APP_USER}
WORKDIR ${APP_ROOT}
# Bring in the packages installed by the `build` stage, then the entry script.
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
# --- prod stage: only adds the start command on top of base-env ---
FROM base-env as prod
# Shell-form CMD so that ${BACKEND_URL} is expanded when the container starts.
# BACKEND_URL is not set anywhere in this Dockerfile; it has to come from the
# container's runtime environment.
CMD python3 __main__.py --backend-url "${BACKEND_URL}"
+85
View File
@@ -0,0 +1,85 @@
# syntax=docker/dockerfile:1
# Global build arguments. They are declared before the first FROM, so each
# stage below has to re-declare (bare `ARG NAME`) the ones it wants to read.
ARG MODULE="inference"
ARG SERVICE="worker"
# Derived values: the unprivileged user name and the service's source path
# relative to the build context.
ARG APP_USER="${MODULE}-${SERVICE}"
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
# --- build stage: install the worker's requirements into /build/lib ---
# Same image as `base-env`, whose APP_LIBS is filled from this stage's output.
FROM python:3.10-alpine3.17 as build
ARG APP_RELATIVE_PATH
WORKDIR /build
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
# The BuildKit cache mount keeps pip's download cache across builds.
RUN --mount=type=cache,target=/var/cache/pip \
pip install \
--cache-dir=/var/cache/pip \
--target=lib \
-r requirements.txt
# --- base-env stage: runtime image shared by the `dev` and `prod` stages ---
FROM python:3.10-alpine3.17 as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
ARG MODULE
ARG SERVICE
# Filesystem layout: the worker code lives in APP_ROOT, its third-party
# packages in APP_LIBS, and shared project libraries under SHARED_LIBS_BASE.
ENV APP_BASE="/opt/${MODULE}"
ENV APP_ROOT="${APP_BASE}/${SERVICE}"
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
ENV SHARED_LIBS_BASE="${APP_BASE}/lib"
# APP_LIBS is populated from a `pip install --target` directory, so it is added
# to PYTHONPATH for imports and its bin/ sub-directory to PATH for commands.
ENV PATH="${PATH}:${APP_LIBS}/bin"
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
# Create the unprivileged user (no password, no home directory).
RUN adduser \
--disabled-password \
--no-create-home \
"${APP_USER}"
USER ${APP_USER}
WORKDIR ${APP_ROOT}
# Bring in the packages installed by the `build` stage, then the entry script.
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
# The start command is defined here and neither `dev` nor `prod` overrides it,
# so both inherit it. Shell form lets the two URLs be expanded at container
# start; neither variable is set in this Dockerfile, so both have to come from
# the container's runtime environment.
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
# --- dev stage: editable install of oasst-shared ---
FROM base-env as dev
ARG APP_USER
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared ${SHARED_LIBS_BASE}/oasst-shared
# Switch to root for the editable (`-e`) install, which has no `--target` and
# therefore goes to pip's default install location; then drop privileges again.
USER root
RUN --mount=type=cache,target=/var/cache/pip,from=build \
pip install \
--cache-dir=/var/cache/pip \
-e "${SHARED_LIBS_BASE}/oasst-shared"
USER ${APP_USER}
# Same path as the oasst-shared copy above (SHARED_LIBS_BASE is ${APP_BASE}/lib).
# NOTE(review): presumably declared so the sources can be mounted from the host
# during development - confirm against the compose file.
VOLUME [ "${APP_BASE}/lib/oasst-shared" ]
# --- prod stage: regular install of oasst-shared into APP_LIBS ---
FROM base-env as prod
ARG APP_USER
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared /tmp/lib/oasst-shared
# No USER instruction in this stage, so pip runs as the APP_USER inherited from
# base-env and installs into APP_LIBS alongside the other dependencies.
RUN --mount=type=cache,target=/var/cache/pip,from=dev \
pip install \
--cache-dir=/var/cache/pip \
--target="${APP_LIBS}" \
/tmp/lib/oasst-shared
@@ -0,0 +1,25 @@
---
title: We Need Your Help!
description: We Need Your Help!
slug: we-need-your-help
authors: yk
tags: [open-assistant]
image: https://img.youtube.com/vi/64Izfm24FKA/0.jpg
---
import ReactPlayer from "react-player";
We Need Your Help!
Help us collect data for OpenAssistant, the largest and most open alternative to
ChatGPT.
https://open-assistant.io
<ReactPlayer
controls
width="100%"
url="https://www.youtube.com/embed/64Izfm24FKA"
/>
<!--truncate-->
+5
View File
@@ -0,0 +1,5 @@
yk:
name: Yannic Kilcher
title: Project Lead
url: https://www.ykilcher.com/
image_url: https://www.ykilcher.com/_next/image?url=%2F_next%2Fstatic%2Fmedia%2Fheadshot.ff3a7ee3.webp&w=3840&q=75
+1 -1
View File
@@ -182,7 +182,7 @@ message GenerationExample {
class RankingExample:
thread: Thread # The conversation thread before the message to be ranked
messages: list[Message] # The messages to be ranked, in oder of decreasing preference
messages: list[Message] # The messages to be ranked, in order of decreasing preference
```
+5 -5
View File
@@ -14,16 +14,16 @@ help.
There are two large-scale projects in the area of instruction-following /
multitask learning: Promptsource and Natural Instructions - these projects
crowdsourced templates and turned existing NLP datasets into
instruction-following seq2seq form in natural langauge. They include both
instruction-following seq2seq form in natural language. They include both
long-output training examples like generating a sentence that is a likely
consequence of sentence in the prompt, and short-output, like rating prediction
from review. (Pre-)training on such datasets should help model understand and
follow instructions and teach it many abilities neccessary to perform a large
set of tasks correctly. However, these data are not dialog-like - they do not
look like a normal conversation.
follow instructions and teach it many abilities necessary to perform a large set
of tasks correctly. However, these data are not dialog-like - they do not look
like a normal conversation.
There are also supervised dialog datasets such as Blended Skill Talk or SODA. In
constrast to instruction-following datasets, dialog data is not as focused on
contrast to instruction-following datasets, dialog data is not as focused on
"academic tasks" or correctness, but encourage the model to respond naturally
like a person would.
+146 -1
View File
@@ -1,3 +1,107 @@
## Questions about the project
### How far along is this project?
We are in the early stages of development, working from established research in
applying RLHF to large language models.
### Can I install Open Assistant locally and chat with it?
The project is not at that stage yet. See
[the plan](https://github.com/LAION-AI/Open-Assistant#the-plan).
### What is the Docker command for?
Only for local development. It does not launch an AI model.
### Is an AI model ready to test yet?
Not yet. The data you help us collect now through
[https://open-assistant.io/](https://open-assistant.io/) will be used to improve
it.
### What license does Open Assistant use?
The code and models are licensed under the Apache 2.0 license.
### Is the model open?
The model will be open. Some very early prototype models are published on
HuggingFace. Follow the discussion in the Discord channel
[#ml-models-demo](https://discord.com/channels/1055935572465700980/1067096888530178048).
### Which base model will be used?
It's still being discussed. Options include Pythia, GPT-J, and a bunch more.
You can follow the discussion in the Discord channel
[#data-discussion](https://discord.com/channels/1055935572465700980/1058348535612985394).
### Can I download the data?
You will be able to, under CC BY 4.0, but it's not released yet. We want to
remove spam and PII before releasing it.
### Who is behind Open Assistant?
Open Assistant is a project organized by [LAION](https://laion.ai/) and
individuals around the world interested in bringing this technology to everyone.
### Will Open Assistant be free?
Yes, Open Assistant will be free to use and modify.
### What hardware will be required to run the models?
There will be versions which will be runnable on consumer hardware.
### How can I contribute?
If you want to help in the data collection for training the model, go to the
website [https://open-assistant.io/](https://open-assistant.io/). If you want to
contribute code, take a look at the
[tasks in GitHub](https://github.com/orgs/LAION-AI/projects/3) and grab one.
Take a look at this
[contributing guide](https://github.com/LAION-AI/Open-Assistant/blob/main/CONTRIBUTING.md).
## Questions about the model training website
### Can I use ChatGPT to help in training Open Assistant, for instance, by generating answers?
No, it is against their terms of service to use it to help train other models.
See
[this issue](https://github.com/LAION-AI/Open-Assistant/issues/471#issuecomment-1374392299).
ChatGPT-like answers will be removed.
### What should I do if I don't know how to complete the task as an assistant?
Skip it.
### Should I fact check the answers by the assistant?
Yes, you should try. If you are not sure, skip the task.
### How can I see my score?
In your [account settings](https://open-assistant.io/account).
### Can we see how many data points have been collected?
There's no public interface for that yet. However, some updates are posted
periodically in
[the #general-discussion Discord channel](https://discord.com/channels/1055935572465700980/1055935573371658252).
Search for `count`.
### How do I write and label prompts?
Check the
[prompting guide](https://projects.laion.ai/Open-Assistant/docs/guides/prompting).
### Where can I report a bug or create a new feature request?
In the [GitHub issues](https://github.com/LAION-AI/Open-Assistant/issues).
## Questions about developing
### Docker-Compose instead of Docker Compose
If you are using `docker-compose` instead of `docker compose` (note the " "
@@ -9,6 +113,33 @@ For more details and information check out
[this SO thread](https://stackoverflow.com/questions/66514436/difference-between-docker-compose-and-docker-compose)
that explains it all in detail.
### Enable Docker's BuildKit Backend
[BuildKit](https://docs.docker.com/build/buildkit/) is Docker's new and improved
builder backend. In addition to being faster and more efficient, it supports
many new features, among which is the ability to provide a persistent cache,
which outlives builds, to compilers and package managers. This is very useful to
speed up consecutive builds, and is used by some container images of
OpenAssistant's stack.
The BuildKit backend is used by
[default by Compose V2](https://www.docker.com/blog/announcing-compose-v2-general-availability/)
(see above). <br/> But if you want to build an image with `docker build` instead
of `docker compose build`, you might need to enable BuildKit.
To do so, just add `DOCKER_BUILDKIT=1` to your environment.
For instance:
```shell
export DOCKER_BUILDKIT=1
```
You could also, more conveniently,
[enable BuildKit by default](https://docs.docker.com/build/buildkit/#:~:text=To%20enable%20docker%20BuildKit%20by%20default),
or use
[Docker Buildx](https://docs.docker.com/build/#:~:text=The%20new%20client%20Docker%20Buildx).
### Pre-commit
We are using pre-commit to ensure the quality of the code as well as the same
@@ -28,7 +159,7 @@ So from now on, in your next commits it will run the `pre-commit` on the files
that have been staged. If there has been any error, you will need to solve that,
and then stage+commit again the changes.
## Docker Cannot Start Container: Permission Denied
### Docker Cannot Start Container: Permission Denied
Instead of running docker with the root command always, you could create a
`docker` group with granted permissions (root):
@@ -63,3 +194,17 @@ getting permission denied (using root user), you can try the following:
# And remove the container
docker rm -f <container id>
```
### Docker Port Problems
Oftentimes people already have some Postgres instance running on the dev
machine. To avoid port problems, change the ports in the `docker-compose.yml` to
ones excluding `5433`, like:
1. Change `db.ports` to `- 5431:5431`.
2. Add `POSTGRES_PORT: 5431` to `db.environment`
3. Change `webdb.ports` to `- 5432:5431`
4. Add `POSTGRES_PORT: 5431` to `webdb.environment`
5. Add `- POSTGRES_PORT=5432` to `backend.environment`
6. Change `web.environment.DATABASE_URL` to
`postgres://postgres:postgres@webdb:5432/oasst_web`
+1 -1
View File
@@ -1,3 +1,3 @@
# Guides
Useful guides.
Useful guides to using [Open-Assistant](https://open-assistant.io/).
+83
View File
@@ -0,0 +1,83 @@
# Data Collection Guide
## Writing Prompts (Initial Prompts or Prompter Replies)
Attempt to post a **diverse range** of prompts. When posting a prompt, consider
whether it is likely to be original. Collecting the same prompt many times is
usually not useful.
Prompts can be on any topic but should not be obscene or hateful. Prompts go
through a moderation process just like answers, so spammy or inappropriate
prompts will be removed.
Do **not** include personally identifiable information (PII) in prompts. This
will be moderated out.
Prompts should not include links which the assistant must open to effectively
answer the prompt.
### Prompts: Spelling and Grammar
For prompts, grammar and spelling are significantly less important than for answers.
If the grammar of a prompt makes it **ambiguous or unreadable** it should be
considered low quality, but if it simply contains common typos this is not a
problem.
## Writing Answers (Assistant Replies)
Answers should aim to match the tone of the prompt. Prompts which are to the
point should receive answers which are to the point. Answers should not be
unnecessarily formal and should be **conversational** in a manner a human may
write. Avoid jargon if possible.
Attempt to answer in a way which accurately fulfils the request in the prompt.
Low quality answers will be moderated. Longer answers tend to be preferable
_unless_ the prompt is sufficiently simple to justify a shorter one.
Where possible, **avoid bias** or expressing subjective preferences in answers.
Do **not** write answers which encourage illegal activities or activities likely
to be harmful to the user or others around them.
### Answers: ChatGPT
**Do not copy answers from ChatGPT** or other OpenAI services. It is prohibited
according to the OpenAI Terms of Service. Obvious ChatGPT-generated answers
should be reported in moderation and rated low quality. Other AI systems may
allow such usage, but answers should still **not** be copied and pasted.
### Answers: Spelling and Grammar
Answers should be held to a high standard of grammar and spelling. It **is**
permissible for answers to contain sources or references provided they are
relevant and from reputable websites. However, this is **not** required.
## Ranking and Classifying Prompts
Largely the criteria for "quality" can be derived by how well the prompt follows
the guidance
[above](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/guides#writing-prompts-initial-prompts-or-prompter-replies).
The other labels are largely self-explanatory.
Spam should be reserved for intentionally low effort or low quality prompts.
It is usually fine to go closer to "Ordinary" than "Creative" unless there is
obvious creativity shown. The same goes for "Serious" compared to "Humorous".
## Ranking and Classifying Answers
Your rankings and classification will help decide which answers are the best and
therefore which ones are prioritised in the final dataset. That means the answer
classification is a crucial task, serving as community moderation.
Largely the criteria for "quality" can be derived by how well the answer follows
the guidance
[above](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/guides#writing-answers-assistant-replies).
The other labels are largely self-explanatory.
Spam should be reserved for intentionally low effort or low quality answers.
The bad answer label should be given when the answer does not respond to the
prompt at all or is irrelevant.
It is usually fine to go closer to "Ordinary" than "Creative" unless there is
obvious creativity shown. The same goes for "Serious" compared to "Humorous".
+1 -1
View File
@@ -91,7 +91,7 @@ following query:
**Agent**: "I apologize, I'm not sure what you're trying to say. Could you
please rephrase it for me?"
**User**: "You dumb motherf\*&ker AI sh\*t?!"
**User**: "You dumb motherf\*&ker AI sh\*t!"
**Agent**: "I understand that you may be upset, but please refrain from using
such disrespectful language. I'm here to help you. Is there something specific
+22 -1
View File
@@ -41,7 +41,22 @@ const config = {
api: {
path: "docs/api/openapi.json",
},
blog: false,
blog: {
routeBasePath: "/blog",
showReadingTime: true,
blogTitle: "OpenAssistant Blog",
blogDescription: "Home of the OpenAssistant blog.",
blogSidebarTitle: "Blog Posts",
blogSidebarCount: "ALL",
postsPerPage: "ALL",
feedOptions: {
type: "all",
title: "OpenAssistant Blog",
description: "Home of the OpenAssistant blog.",
language: "en",
copyright: `Copyright © ${new Date().getFullYear()} OpenAssistant.`,
},
},
theme: {
customCss: require.resolve("./src/css/custom.css"),
},
@@ -59,12 +74,18 @@ const config = {
src: "img/logo.svg",
},
items: [
{
href: "https://open-assistant.io/",
label: "App",
position: "left",
},
{
type: "doc",
docId: "intro",
position: "left",
label: "Docs",
},
{ to: "/blog", label: "Blog", position: "left" },
{ to: "/api", label: "API", position: "left" },
{
href: "https://github.com/LAION-AI/Open-Assistant",
+4 -3
View File
@@ -15,18 +15,19 @@
"typecheck": "tsc"
},
"dependencies": {
"@docusaurus/core": "2.2.0",
"@docusaurus/preset-classic": "2.2.0",
"@docusaurus/core": "^2.3.1",
"@docusaurus/preset-classic": "^2.3.1",
"@mdx-js/react": "^1.6.22",
"clsx": "^1.2.1",
"docusaurus-preset-openapi": "^0.6.3",
"prism-react-renderer": "^1.3.5",
"react": "^17.0.2",
"react-dom": "^17.0.2",
"react-player": "^2.11.0",
"url": "^0.11.0"
},
"devDependencies": {
"@docusaurus/module-type-aliases": "2.2.0",
"@docusaurus/module-type-aliases": "^2.3.1",
"@tsconfig/docusaurus": "^1.0.5",
"typescript": "^4.7.4"
},
+440 -2
View File
@@ -1275,6 +1275,83 @@
webpack-merge "^5.8.0"
webpackbar "^5.0.2"
"@docusaurus/core@2.3.1", "@docusaurus/core@^2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/core/-/core-2.3.1.tgz#32849f2ffd2f086a4e55739af8c4195c5eb386f2"
integrity sha512-0Jd4jtizqnRAr7svWaBbbrCCN8mzBNd2xFLoT/IM7bGfFie5y58oz97KzXliwiLY3zWjqMXjQcuP1a5VgCv2JA==
dependencies:
"@babel/core" "^7.18.6"
"@babel/generator" "^7.18.7"
"@babel/plugin-syntax-dynamic-import" "^7.8.3"
"@babel/plugin-transform-runtime" "^7.18.6"
"@babel/preset-env" "^7.18.6"
"@babel/preset-react" "^7.18.6"
"@babel/preset-typescript" "^7.18.6"
"@babel/runtime" "^7.18.6"
"@babel/runtime-corejs3" "^7.18.6"
"@babel/traverse" "^7.18.8"
"@docusaurus/cssnano-preset" "2.3.1"
"@docusaurus/logger" "2.3.1"
"@docusaurus/mdx-loader" "2.3.1"
"@docusaurus/react-loadable" "5.5.2"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-common" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
"@slorber/static-site-generator-webpack-plugin" "^4.0.7"
"@svgr/webpack" "^6.2.1"
autoprefixer "^10.4.7"
babel-loader "^8.2.5"
babel-plugin-dynamic-import-node "^2.3.3"
boxen "^6.2.1"
chalk "^4.1.2"
chokidar "^3.5.3"
clean-css "^5.3.0"
cli-table3 "^0.6.2"
combine-promises "^1.1.0"
commander "^5.1.0"
copy-webpack-plugin "^11.0.0"
core-js "^3.23.3"
css-loader "^6.7.1"
css-minimizer-webpack-plugin "^4.0.0"
cssnano "^5.1.12"
del "^6.1.1"
detect-port "^1.3.0"
escape-html "^1.0.3"
eta "^2.0.0"
file-loader "^6.2.0"
fs-extra "^10.1.0"
html-minifier-terser "^6.1.0"
html-tags "^3.2.0"
html-webpack-plugin "^5.5.0"
import-fresh "^3.3.0"
leven "^3.1.0"
lodash "^4.17.21"
mini-css-extract-plugin "^2.6.1"
postcss "^8.4.14"
postcss-loader "^7.0.0"
prompts "^2.4.2"
react-dev-utils "^12.0.1"
react-helmet-async "^1.3.0"
react-loadable "npm:@docusaurus/react-loadable@5.5.2"
react-loadable-ssr-addon-v5-slorber "^1.0.1"
react-router "^5.3.3"
react-router-config "^5.1.1"
react-router-dom "^5.3.3"
rtl-detect "^1.0.4"
semver "^7.3.7"
serve-handler "^6.1.3"
shelljs "^0.8.5"
terser-webpack-plugin "^5.3.3"
tslib "^2.4.0"
update-notifier "^5.1.0"
url-loader "^4.1.1"
wait-on "^6.0.1"
webpack "^5.73.0"
webpack-bundle-analyzer "^4.5.0"
webpack-dev-server "^4.9.3"
webpack-merge "^5.8.0"
webpackbar "^5.0.2"
"@docusaurus/cssnano-preset@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/cssnano-preset/-/cssnano-preset-2.2.0.tgz#fc05044659051ae74ab4482afcf4a9936e81d523"
@@ -1285,6 +1362,16 @@
postcss-sort-media-queries "^4.2.1"
tslib "^2.4.0"
"@docusaurus/cssnano-preset@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/cssnano-preset/-/cssnano-preset-2.3.1.tgz#e042487655e3e062417855e12edb3f6eee8f5ecb"
integrity sha512-7mIhAROES6CY1GmCjR4CZkUfjTL6B3u6rKHK0ChQl2d1IevYXq/k/vFgvOrJfcKxiObpMnE9+X6R2Wt1KqxC6w==
dependencies:
cssnano-preset-advanced "^5.3.8"
postcss "^8.4.14"
postcss-sort-media-queries "^4.2.1"
tslib "^2.4.0"
"@docusaurus/logger@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/logger/-/logger-2.2.0.tgz#ea2f7feda7b8675485933b87f06d9c976d17423f"
@@ -1293,6 +1380,14 @@
chalk "^4.1.2"
tslib "^2.4.0"
"@docusaurus/logger@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/logger/-/logger-2.3.1.tgz#d76aefb452e3734b4e0e645efc6cbfc0aae52869"
integrity sha512-2lAV/olKKVr9qJhfHFCaqBIl8FgYjbUFwgUnX76+cULwQYss+42ZQ3grHGFvI0ocN2X55WcYe64ellQXz7suqg==
dependencies:
chalk "^4.1.2"
tslib "^2.4.0"
"@docusaurus/mdx-loader@2.2.0", "@docusaurus/mdx-loader@^2.0.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/mdx-loader/-/mdx-loader-2.2.0.tgz#fd558f429e5d9403d284bd4214e54d9768b041a0"
@@ -1316,6 +1411,29 @@
url-loader "^4.1.1"
webpack "^5.73.0"
"@docusaurus/mdx-loader@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/mdx-loader/-/mdx-loader-2.3.1.tgz#7ec6acee5eff0a280e1b399ea4dd690b15a793f7"
integrity sha512-Gzga7OsxQRpt3392K9lv/bW4jGppdLFJh3luKRknCKSAaZrmVkOQv2gvCn8LAOSZ3uRg5No7AgYs/vpL8K94lA==
dependencies:
"@babel/parser" "^7.18.8"
"@babel/traverse" "^7.18.8"
"@docusaurus/logger" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@mdx-js/mdx" "^1.6.22"
escape-html "^1.0.3"
file-loader "^6.2.0"
fs-extra "^10.1.0"
image-size "^1.0.1"
mdast-util-to-string "^2.0.0"
remark-emoji "^2.2.0"
stringify-object "^3.3.0"
tslib "^2.4.0"
unified "^9.2.2"
unist-util-visit "^2.0.3"
url-loader "^4.1.1"
webpack "^5.73.0"
"@docusaurus/module-type-aliases@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/module-type-aliases/-/module-type-aliases-2.2.0.tgz#1e23e54a1bbb6fde1961e4fa395b1b69f4803ba5"
@@ -1330,6 +1448,20 @@
react-helmet-async "*"
react-loadable "npm:@docusaurus/react-loadable@5.5.2"
"@docusaurus/module-type-aliases@2.3.1", "@docusaurus/module-type-aliases@^2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/module-type-aliases/-/module-type-aliases-2.3.1.tgz#986186200818fed999be2e18d6c698eaf4683a33"
integrity sha512-6KkxfAVOJqIUynTRb/tphYCl+co3cP0PlHiMDbi+SzmYxMdgIrwYqH9yAnGSDoN6Jk2ZE/JY/Azs/8LPgKP48A==
dependencies:
"@docusaurus/react-loadable" "5.5.2"
"@docusaurus/types" "2.3.1"
"@types/history" "^4.7.11"
"@types/react" "*"
"@types/react-router-config" "*"
"@types/react-router-dom" "*"
react-helmet-async "*"
react-loadable "npm:@docusaurus/react-loadable@5.5.2"
"@docusaurus/plugin-content-blog@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-blog/-/plugin-content-blog-2.2.0.tgz#dc55982e76771f4e678ac10e26d10e1da2011dc1"
@@ -1352,6 +1484,28 @@
utility-types "^3.10.0"
webpack "^5.73.0"
"@docusaurus/plugin-content-blog@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-blog/-/plugin-content-blog-2.3.1.tgz#236b8ee4f20f7047aa9c285ae77ae36683ad48a3"
integrity sha512-f5LjqX+9WkiLyGiQ41x/KGSJ/9bOjSD8lsVhPvYeUYHCtYpuiDKfhZE07O4EqpHkBx4NQdtQDbp+aptgHSTuiw==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/logger" "2.3.1"
"@docusaurus/mdx-loader" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-common" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
cheerio "^1.0.0-rc.12"
feed "^4.2.2"
fs-extra "^10.1.0"
lodash "^4.17.21"
reading-time "^1.5.0"
tslib "^2.4.0"
unist-util-visit "^2.0.3"
utility-types "^3.10.0"
webpack "^5.73.0"
"@docusaurus/plugin-content-docs@2.2.0", "@docusaurus/plugin-content-docs@^2.0.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-docs/-/plugin-content-docs-2.2.0.tgz#0fcb85226fcdb80dc1e2d4a36ef442a650dcc84d"
@@ -1374,6 +1528,28 @@
utility-types "^3.10.0"
webpack "^5.73.0"
"@docusaurus/plugin-content-docs@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-docs/-/plugin-content-docs-2.3.1.tgz#feae1555479558a55182f22f8a07acc5e0d7444d"
integrity sha512-DxztTOBEruv7qFxqUtbsqXeNcHqcVEIEe+NQoI1oi2DBmKBhW/o0MIal8lt+9gvmpx3oYtlwmLOOGepxZgJGkw==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/logger" "2.3.1"
"@docusaurus/mdx-loader" "2.3.1"
"@docusaurus/module-type-aliases" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
"@types/react-router-config" "^5.0.6"
combine-promises "^1.1.0"
fs-extra "^10.1.0"
import-fresh "^3.3.0"
js-yaml "^4.1.0"
lodash "^4.17.21"
tslib "^2.4.0"
utility-types "^3.10.0"
webpack "^5.73.0"
"@docusaurus/plugin-content-pages@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-pages/-/plugin-content-pages-2.2.0.tgz#e3f40408787bbe229545dd50595f87e1393bc3ae"
@@ -1388,6 +1564,20 @@
tslib "^2.4.0"
webpack "^5.73.0"
"@docusaurus/plugin-content-pages@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-content-pages/-/plugin-content-pages-2.3.1.tgz#f534a37862be5b3f2ba5b150458d7527646b6f39"
integrity sha512-E80UL6hvKm5VVw8Ka8YaVDtO6kWWDVUK4fffGvkpQ/AJQDOg99LwOXKujPoICC22nUFTsZ2Hp70XvpezCsFQaA==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/mdx-loader" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
fs-extra "^10.1.0"
tslib "^2.4.0"
webpack "^5.73.0"
"@docusaurus/plugin-debug@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-debug/-/plugin-debug-2.2.0.tgz#b38741d2c492f405fee01ee0ef2e0029cedb689a"
@@ -1400,6 +1590,18 @@
react-json-view "^1.21.3"
tslib "^2.4.0"
"@docusaurus/plugin-debug@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-debug/-/plugin-debug-2.3.1.tgz#26fef904713e148f6dee44957506280f8b7853bb"
integrity sha512-Ujpml1Ppg4geB/2hyu2diWnO49az9U2bxM9Shen7b6qVcyFisNJTkVG2ocvLC7wM1efTJcUhBO6zAku2vKJGMw==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils" "2.3.1"
fs-extra "^10.1.0"
react-json-view "^1.21.3"
tslib "^2.4.0"
"@docusaurus/plugin-google-analytics@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-analytics/-/plugin-google-analytics-2.2.0.tgz#63c7137eff5a1208d2059fea04b5207c037d7954"
@@ -1410,6 +1612,16 @@
"@docusaurus/utils-validation" "2.2.0"
tslib "^2.4.0"
"@docusaurus/plugin-google-analytics@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-analytics/-/plugin-google-analytics-2.3.1.tgz#e2e7db4cf6a7063e8ba5e128d4e413f4d6a0c862"
integrity sha512-OHip0GQxKOFU8n7gkt3TM4HOYTXPCFDjqKbMClDD3KaDnyTuMp/Zvd9HSr770lLEscgPWIvzhJByRAClqsUWiQ==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
tslib "^2.4.0"
"@docusaurus/plugin-google-gtag@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-gtag/-/plugin-google-gtag-2.2.0.tgz#7b086d169ac5fe9a88aca10ab0fd2bf00c6c6b12"
@@ -1420,6 +1632,26 @@
"@docusaurus/utils-validation" "2.2.0"
tslib "^2.4.0"
"@docusaurus/plugin-google-gtag@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-gtag/-/plugin-google-gtag-2.3.1.tgz#b8da54a60c0a50aca609c3643faef78cb4f247a0"
integrity sha512-uXtDhfu4+Hm+oqWUySr3DNI5cWC/rmP6XJyAk83Heor3dFjZqDwCbkX8yWPywkRiWev3Dk/rVF8lEn0vIGVocA==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
tslib "^2.4.0"
"@docusaurus/plugin-google-tag-manager@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-google-tag-manager/-/plugin-google-tag-manager-2.3.1.tgz#f19bc01cc784fa4734187c5bc637f0574857e15d"
integrity sha512-Ww2BPEYSqg8q8tJdLYPFFM3FMDBCVhEM4UUqKzJaiRMx3NEoly3qqDRAoRDGdIhlC//Rf0iJV9cWAoq2m6k3sw==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
tslib "^2.4.0"
"@docusaurus/plugin-sitemap@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-sitemap/-/plugin-sitemap-2.2.0.tgz#876da60937886032d63143253d420db6a4b34773"
@@ -1435,7 +1667,22 @@
sitemap "^7.1.1"
tslib "^2.4.0"
"@docusaurus/preset-classic@2.2.0", "@docusaurus/preset-classic@^2.0.0":
"@docusaurus/plugin-sitemap@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/plugin-sitemap/-/plugin-sitemap-2.3.1.tgz#f526ab517ca63b7a3460d585876f5952cb908aa0"
integrity sha512-8Yxile/v6QGYV9vgFiYL+8d2N4z4Er3pSHsrD08c5XI8bUXxTppMwjarDUTH/TRTfgAWotRbhJ6WZLyajLpozA==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/logger" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-common" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
fs-extra "^10.1.0"
sitemap "^7.1.1"
tslib "^2.4.0"
"@docusaurus/preset-classic@^2.0.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/preset-classic/-/preset-classic-2.2.0.tgz#bece5a043eeb74430f7c6c7510000b9c43669eb7"
integrity sha512-yKIWPGNx7BT8v2wjFIWvYrS+nvN04W+UameSFf8lEiJk6pss0kL6SG2MRvyULiI3BDxH+tj6qe02ncpSPGwumg==
@@ -1453,6 +1700,25 @@
"@docusaurus/theme-search-algolia" "2.2.0"
"@docusaurus/types" "2.2.0"
"@docusaurus/preset-classic@^2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/preset-classic/-/preset-classic-2.3.1.tgz#f0193f06093eb55cafef66bd1ad9e0d33198bf95"
integrity sha512-OQ5W0AHyfdUk0IldwJ3BlnZ1EqoJuu2L2BMhqLbqwNWdkmzmSUvlFLH1Pe7CZSQgB2YUUC/DnmjbPKk/qQD0lQ==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/plugin-content-blog" "2.3.1"
"@docusaurus/plugin-content-docs" "2.3.1"
"@docusaurus/plugin-content-pages" "2.3.1"
"@docusaurus/plugin-debug" "2.3.1"
"@docusaurus/plugin-google-analytics" "2.3.1"
"@docusaurus/plugin-google-gtag" "2.3.1"
"@docusaurus/plugin-google-tag-manager" "2.3.1"
"@docusaurus/plugin-sitemap" "2.3.1"
"@docusaurus/theme-classic" "2.3.1"
"@docusaurus/theme-common" "2.3.1"
"@docusaurus/theme-search-algolia" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/react-loadable@5.5.2", "react-loadable@npm:@docusaurus/react-loadable@5.5.2":
version "5.5.2"
resolved "https://registry.yarnpkg.com/@docusaurus/react-loadable/-/react-loadable-5.5.2.tgz#81aae0db81ecafbdaee3651f12804580868fa6ce"
@@ -1492,6 +1758,37 @@
tslib "^2.4.0"
utility-types "^3.10.0"
"@docusaurus/theme-classic@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-classic/-/theme-classic-2.3.1.tgz#8e6e194236e702c0d4e8d7b7cbb6886ae456e598"
integrity sha512-SelSIDvyttb7ZYHj8vEUhqykhAqfOPKk+uP0z85jH72IMC58e7O8DIlcAeBv+CWsLbNIl9/Hcg71X0jazuxJug==
dependencies:
"@docusaurus/core" "2.3.1"
"@docusaurus/mdx-loader" "2.3.1"
"@docusaurus/module-type-aliases" "2.3.1"
"@docusaurus/plugin-content-blog" "2.3.1"
"@docusaurus/plugin-content-docs" "2.3.1"
"@docusaurus/plugin-content-pages" "2.3.1"
"@docusaurus/theme-common" "2.3.1"
"@docusaurus/theme-translations" "2.3.1"
"@docusaurus/types" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-common" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
"@mdx-js/react" "^1.6.22"
clsx "^1.2.1"
copy-text-to-clipboard "^3.0.1"
infima "0.2.0-alpha.42"
lodash "^4.17.21"
nprogress "^0.2.0"
postcss "^8.4.14"
prism-react-renderer "^1.3.5"
prismjs "^1.28.0"
react-router-dom "^5.3.3"
rtlcss "^3.5.0"
tslib "^2.4.0"
utility-types "^3.10.0"
"@docusaurus/theme-common@2.2.0", "@docusaurus/theme-common@^2.0.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-common/-/theme-common-2.2.0.tgz#2303498d80448aafdd588b597ce9d6f4cfa930e4"
@@ -1512,6 +1809,27 @@
tslib "^2.4.0"
utility-types "^3.10.0"
"@docusaurus/theme-common@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-common/-/theme-common-2.3.1.tgz#82f52d80226efef8c4418c4eacfc5051aa215f7f"
integrity sha512-RYmYl2OR2biO+yhmW1aS5FyEvnrItPINa+0U2dMxcHpah8reSCjQ9eJGRmAgkZFchV1+aIQzXOI1K7LCW38O0g==
dependencies:
"@docusaurus/mdx-loader" "2.3.1"
"@docusaurus/module-type-aliases" "2.3.1"
"@docusaurus/plugin-content-blog" "2.3.1"
"@docusaurus/plugin-content-docs" "2.3.1"
"@docusaurus/plugin-content-pages" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@types/history" "^4.7.11"
"@types/react" "*"
"@types/react-router-config" "*"
clsx "^1.2.1"
parse-numeric-range "^1.3.0"
prism-react-renderer "^1.3.5"
tslib "^2.4.0"
use-sync-external-store "^1.2.0"
utility-types "^3.10.0"
"@docusaurus/theme-search-algolia@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-search-algolia/-/theme-search-algolia-2.2.0.tgz#77fd9f7a600917e6024fe3ac7fb6cfdf2ce84737"
@@ -1534,6 +1852,28 @@
tslib "^2.4.0"
utility-types "^3.10.0"
"@docusaurus/theme-search-algolia@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-search-algolia/-/theme-search-algolia-2.3.1.tgz#d587b40913119e9287d14670e277b933d8f453f0"
integrity sha512-JdHaRqRuH1X++g5fEMLnq7OtULSGQdrs9AbhcWRQ428ZB8/HOiaN6mj3hzHvcD3DFgu7koIVtWPQnvnN7iwzHA==
dependencies:
"@docsearch/react" "^3.1.1"
"@docusaurus/core" "2.3.1"
"@docusaurus/logger" "2.3.1"
"@docusaurus/plugin-content-docs" "2.3.1"
"@docusaurus/theme-common" "2.3.1"
"@docusaurus/theme-translations" "2.3.1"
"@docusaurus/utils" "2.3.1"
"@docusaurus/utils-validation" "2.3.1"
algoliasearch "^4.13.1"
algoliasearch-helper "^3.10.0"
clsx "^1.2.1"
eta "^2.0.0"
fs-extra "^10.1.0"
lodash "^4.17.21"
tslib "^2.4.0"
utility-types "^3.10.0"
"@docusaurus/theme-translations@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-translations/-/theme-translations-2.2.0.tgz#5fbd4693679806f80c26eeae1381e1f2c23d83e7"
@@ -1542,6 +1882,14 @@
fs-extra "^10.1.0"
tslib "^2.4.0"
"@docusaurus/theme-translations@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/theme-translations/-/theme-translations-2.3.1.tgz#b2b1ecc00a737881b5bfabc19f90b20f0fe02bb3"
integrity sha512-BsBZzAewJabVhoGG1Ij2u4pMS3MPW6gZ6sS4pc+Y7czevRpzxoFNJXRtQDVGe7mOpv/MmRmqg4owDK+lcOTCVQ==
dependencies:
fs-extra "^10.1.0"
tslib "^2.4.0"
"@docusaurus/types@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/types/-/types-2.2.0.tgz#02c577a4041ab7d058a3c214ccb13647e21a9857"
@@ -1556,6 +1904,20 @@
webpack "^5.73.0"
webpack-merge "^5.8.0"
"@docusaurus/types@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/types/-/types-2.3.1.tgz#785ade2e0f4e35e1eb7fb0d04c27d11c3991a2e8"
integrity sha512-PREbIRhTaNNY042qmfSE372Jb7djZt+oVTZkoqHJ8eff8vOIc2zqqDqBVc5BhOfpZGPTrE078yy/torUEZy08A==
dependencies:
"@types/history" "^4.7.11"
"@types/react" "*"
commander "^5.1.0"
joi "^17.6.0"
react-helmet-async "^1.3.0"
utility-types "^3.10.0"
webpack "^5.73.0"
webpack-merge "^5.8.0"
"@docusaurus/utils-common@2.2.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/utils-common/-/utils-common-2.2.0.tgz#a401c1b93a8697dd566baf6ac64f0fdff1641a78"
@@ -1563,6 +1925,13 @@
dependencies:
tslib "^2.4.0"
"@docusaurus/utils-common@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/utils-common/-/utils-common-2.3.1.tgz#1abe66846eb641547e4964d44f3011938e58e50b"
integrity sha512-pVlRpXkdNcxmKNxAaB1ya2hfCEvVsLDp2joeM6K6uv55Oc5nVIqgyYSgSNKZyMdw66NnvMfsu0RBylcwZQKo9A==
dependencies:
tslib "^2.4.0"
"@docusaurus/utils-validation@2.2.0", "@docusaurus/utils-validation@^2.0.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/utils-validation/-/utils-validation-2.2.0.tgz#04d4d103137ad0145883971d3aa497f4a1315f25"
@@ -1574,6 +1943,17 @@
js-yaml "^4.1.0"
tslib "^2.4.0"
"@docusaurus/utils-validation@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/utils-validation/-/utils-validation-2.3.1.tgz#b65c718ba9b84b7a891bccf5ac6d19b57ee7d887"
integrity sha512-7n0208IG3k1HVTByMHlZoIDjjOFC8sbViHVXJx0r3Q+3Ezrx+VQ1RZ/zjNn6lT+QBCRCXlnlaoJ8ug4HIVgQ3w==
dependencies:
"@docusaurus/logger" "2.3.1"
"@docusaurus/utils" "2.3.1"
joi "^17.6.0"
js-yaml "^4.1.0"
tslib "^2.4.0"
"@docusaurus/utils@2.2.0", "@docusaurus/utils@^2.0.0":
version "2.2.0"
resolved "https://registry.yarnpkg.com/@docusaurus/utils/-/utils-2.2.0.tgz#3d6f9b7a69168d5c92d371bf21c556a4f50d1da6"
@@ -1595,6 +1975,28 @@
url-loader "^4.1.1"
webpack "^5.73.0"
"@docusaurus/utils@2.3.1":
version "2.3.1"
resolved "https://registry.yarnpkg.com/@docusaurus/utils/-/utils-2.3.1.tgz#24b9cae3a23b1e6dc88f95c45722c7e82727b032"
integrity sha512-9WcQROCV0MmrpOQDXDGhtGMd52DHpSFbKLfkyaYumzbTstrbA5pPOtiGtxK1nqUHkiIv8UwexS54p0Vod2I1lg==
dependencies:
"@docusaurus/logger" "2.3.1"
"@svgr/webpack" "^6.2.1"
escape-string-regexp "^4.0.0"
file-loader "^6.2.0"
fs-extra "^10.1.0"
github-slugger "^1.4.0"
globby "^11.1.0"
gray-matter "^4.0.3"
js-yaml "^4.1.0"
lodash "^4.17.21"
micromatch "^4.0.5"
resolve-pathname "^3.0.0"
shelljs "^0.8.5"
tslib "^2.4.0"
url-loader "^4.1.1"
webpack "^5.73.0"
"@faker-js/faker@5.5.3":
version "5.5.3"
resolved "https://registry.yarnpkg.com/@faker-js/faker/-/faker-5.5.3.tgz#18e3af6b8eae7984072bbeb0c0858474d7c4cefe"
@@ -3456,6 +3858,11 @@ deep-extend@^0.6.0:
resolved "https://registry.yarnpkg.com/deep-extend/-/deep-extend-0.6.0.tgz#c4fa7c95404a17a9c3e8ca7e1537312b736330ac"
integrity sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==
deepmerge@^4.0.0:
version "4.3.0"
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.3.0.tgz#65491893ec47756d44719ae520e0e2609233b59b"
integrity sha512-z2wJZXrmeHdvYJp/Ux55wIjqo81G5Bp4c+oELTW+7ar6SogWHajt5a9gO3s3IDaGSAXjDk0vlQKN3rms8ab3og==
deepmerge@^4.2.2:
version "4.2.2"
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955"
@@ -3899,6 +4306,11 @@ eta@^1.12.3:
resolved "https://registry.yarnpkg.com/eta/-/eta-1.12.3.tgz#2982d08adfbef39f9fa50e2fbd42d7337e7338b1"
integrity sha512-qHixwbDLtekO/d51Yr4glcaUJCIjGVJyTzuqV4GPlgZo1YpgOKG+avQynErZIYrfM6JIJdtiG2Kox8tbb+DoGg==
eta@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/eta/-/eta-2.0.0.tgz#376865fadebc899e5b6dfce82fae64cbbe47e594"
integrity sha512-NqE7S2VmVwgMS8yBxsH4VgNQjNjLq1gfGU0u9I6Cjh468nPRMoDfGdK9n1p/3Dvsw3ebklDkZsFAnKJ9sefjBA==
etag@~1.8.1:
version "1.8.1"
resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887"
@@ -5254,6 +5666,11 @@ liquid-json@0.3.1:
resolved "https://registry.yarnpkg.com/liquid-json/-/liquid-json-0.3.1.tgz#9155a18136d8a6b2615e5f16f9a2448ab6b50eea"
integrity sha512-wUayTU8MS827Dam6MxgD72Ui+KOSF+u/eIqpatOtjnvgJ0+mnDq33uC2M7J0tPK+upe/DpUAuK4JUU89iBoNKQ==
load-script@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/load-script/-/load-script-1.0.0.tgz#0491939e0bee5643ee494a7e3da3d2bac70c6ca4"
integrity sha512-kPEjMFtZvwL9TaZo0uZ2ml+Ye9HUMmPwbYRJ324qF9tqMejwykJ5ggTyvzmrbBeapCAbk98BSbTeovHEEP1uCA==
loader-runner@^4.2.0:
version "4.3.0"
resolved "https://registry.yarnpkg.com/loader-runner/-/loader-runner-4.3.0.tgz#c1b4a163b99f614830353b16755e7149ac2314e1"
@@ -5470,6 +5887,11 @@ memfs@^3.1.2, memfs@^3.4.3:
dependencies:
fs-monkey "^1.0.3"
memoize-one@^5.1.1:
version "5.2.1"
resolved "https://registry.yarnpkg.com/memoize-one/-/memoize-one-5.2.1.tgz#8337aa3c4335581839ec01c3d594090cebe8f00e"
integrity sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q==
merge-descriptors@1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/merge-descriptors/-/merge-descriptors-1.0.1.tgz#b00aaa556dd8b44568150ec9d1b953f3f90cbb61"
@@ -6715,7 +7137,7 @@ react-error-overlay@^6.0.11:
resolved "https://registry.yarnpkg.com/react-error-overlay/-/react-error-overlay-6.0.11.tgz#92835de5841c5cf08ba00ddd2d677b6d17ff9adb"
integrity sha512-/6UZ2qgEyH2aqzYZgQPxEnz33NJ2gNsnHA2o5+o4wW9bLM/JYQitNP9xPhsXwC08hMMovfGe/8retsdDsczPRg==
react-fast-compare@^3.2.0:
react-fast-compare@^3.0.1, react-fast-compare@^3.2.0:
version "3.2.0"
resolved "https://registry.yarnpkg.com/react-fast-compare/-/react-fast-compare-3.2.0.tgz#641a9da81b6a6320f270e89724fb45a0b39e43bb"
integrity sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA==
@@ -6768,6 +7190,17 @@ react-magic-dropzone@^1.0.1:
resolved "https://registry.yarnpkg.com/react-magic-dropzone/-/react-magic-dropzone-1.0.1.tgz#bfd25b77b57e7a04aaef0a28910563b707ee54df"
integrity sha512-0BIROPARmXHpk4AS3eWBOsewxoM5ndk2psYP/JmbCq8tz3uR2LIV1XiroZ9PKrmDRMctpW+TvsBCtWasuS8vFA==
react-player@^2.11.0:
version "2.11.0"
resolved "https://registry.yarnpkg.com/react-player/-/react-player-2.11.0.tgz#9afc75314eb915238e8d6615b2891fbe7170aeaa"
integrity sha512-fIrwpuXOBXdEg1FiyV9isKevZOaaIsAAtZy5fcjkQK9Nhmk1I2NXzY/hkPos8V0zb/ZX416LFy8gv7l/1k3a5w==
dependencies:
deepmerge "^4.0.0"
load-script "^1.0.0"
memoize-one "^5.1.1"
prop-types "^15.7.2"
react-fast-compare "^3.0.1"
react-redux@^7.2.0:
version "7.2.9"
resolved "https://registry.yarnpkg.com/react-redux/-/react-redux-7.2.9.tgz#09488fbb9416a4efe3735b7235055442b042481d"
@@ -8091,6 +8524,11 @@ use-latest@^1.2.1:
dependencies:
use-isomorphic-layout-effect "^1.1.1"
use-sync-external-store@^1.2.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/use-sync-external-store/-/use-sync-external-store-1.2.0.tgz#7dbefd6ef3fe4e767a0cf5d7287aacfb5846928a"
integrity sha512-eEgnFxGQ1Ife9bzYs6VLi8/4X6CObHMw9Qr9tPY43iKwsPw8xE8+EFsf/2cFZ5S3esXgpWgtSCtLNS41F+sKPA==
util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"
+54 -4
View File
@@ -1,14 +1,64 @@
# OpenAssitant Inference
# OpenAssistant Inference
Preliminary implementation of the inference engine for OpenAssistant.
## Development Variant 1 (you'll need tmux)
## Development Variant 1 (docker compose)
The services of the inference stack are prefixed with "inference-" in the
[unified compose descriptor](../docker-compose.yaml). <br/> Prior to building
those, please ensure that you have Docker's new
[BuildKit](https://docs.docker.com/build/buildkit/) backend enabled. See the
[FAQ](https://projects.laion.ai/Open-Assistant/docs/faq#enable-dockers-buildkit-backend)
for more info.
To build the services, run:
```shell
docker compose --profile inference build
```
Spin up the stack:
```shell
docker compose --profile inference up -d
```
Tail the logs:
```shell
docker compose logs -f \
inference-server \
inference-worker \
inference-text-client \
inference-text-generation-server
```
Attach to the text-client, and start chatting:
```shell
docker attach open-assistant-inference-text-client-1
```
> **Note:** In the last step, `open-assistant-inference-text-client-1` refers to
> the name of the `text-client` container started in step 2.
> **Note:** The compose file contains the bind mounts enabling you to develop on
> the modules of the inference stack, and the `oasst-shared` package, without
> rebuilding.
> **Note:** You can spin up any number of workers by adjusting the number of
> replicas of the `inference-worker` service to your liking.
> **Note:** Please wait for the `inference-text-generation-server` service to
> output `{"message":"Connected"}` before starting to chat.
## Development Variant 2 (you'll need tmux)
Run `./full-dev-setup.sh` to start the full development setup. Make sure to wait
until the 2nd terminal is ready and says `{"message":"Connected"}` before
entering input into the last terminal.
## Development Variant 2 (you'll need multiple terminals)
## Development Variant 3 (you'll need multiple terminals)
Run a redis container (or use the one of the general docker compose file):
@@ -36,7 +86,7 @@ For the worker, you'll also want to have the text-generation-inference server
running:
```bash
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference
```
Run the client:
+1 -1
View File
@@ -5,7 +5,7 @@
tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m
+37 -14
View File
@@ -39,14 +39,6 @@ redisClient = redis.Redis(
)
# Request body accepted when a client creates a chat; it currently declares no fields.
class CreateChatRequest(pydantic.BaseModel):
    pass
# Response body returned after creating a chat: just the id of the new chat.
class CreateChatResponse(pydantic.BaseModel):
    id: str
class MessageRequest(pydantic.BaseModel):
message: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
@@ -57,7 +49,7 @@ class MessageRequest(pydantic.BaseModel):
class TokenResponseEvent(pydantic.BaseModel):
token: str
token: inference.TokenResponse
class MessageRequestState(str, enum.Enum):
@@ -67,30 +59,61 @@ class MessageRequestState(str, enum.Enum):
aborted_by_worker = "aborted_by_worker"
# Request body accepted when a client creates a chat; it currently declares no fields.
class CreateChatRequest(pydantic.BaseModel):
    pass
# Minimal, id-only view of a chat (what DbChatEntry.to_list_entry produces).
class ChatListEntry(pydantic.BaseModel):
    id: str
# View of a single chat that carries its id together with the conversation
# (what DbChatEntry.to_entry produces).
class ChatEntry(pydantic.BaseModel):
    id: str
    conversation: protocol.Conversation
# Response body wrapping the list of id-only chat entries.
class ListChatsResponse(pydantic.BaseModel):
    chats: list[ChatListEntry]
# Stored record for one chat: its id, its conversation, and the message request
# (plus that request's state) that is currently pending, if any.
class DbChatEntry(pydantic.BaseModel):
    # A fresh random UUID4 string is generated for every new entry.
    id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
    # Every new entry starts with its own default-constructed conversation.
    conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
    pending_message_request: MessageRequest | None = None
    message_request_state: MessageRequestState | None = None

    def to_list_entry(self) -> ChatListEntry:
        """Return the id-only view of this chat."""
        return ChatListEntry(id=self.id)

    def to_entry(self) -> ChatEntry:
        """Return the view of this chat that includes its conversation."""
        return ChatEntry(id=self.id, conversation=self.conversation)
# TODO: make real database
CHATS: dict[str, DbChatEntry] = {}
@app.get("/chat")
async def list_chats() -> ListChatsResponse:
    """Lists all chats."""
    logger.info("Listing all chats.")
    # Only the id-only view of each stored chat is returned, not the conversations.
    chats = [chat.to_list_entry() for chat in CHATS.values()]
    return ListChatsResponse(chats=chats)
@app.post("/chat")
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
"""Allows a client to create a new chat."""
logger.info(f"Received {request}")
chat = DbChatEntry()
CHATS[chat.id] = chat
return CreateChatResponse(id=chat.id)
return ChatListEntry(id=chat.id)
@app.get("/chat/{id}")
async def get_chat(id: str) -> protocol.Conversation:
async def get_chat(id: str) -> ChatEntry:
"""Allows a client to get the current state of a chat."""
return CHATS[id].conversation
return CHATS[id].to_entry()
@app.post("/chat/{id}/message")
@@ -143,7 +166,7 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
chat.conversation.messages.append(
protocol.ConversationMessage(
text="".join([d.token for d in result_data[:-1]]),
text=response_packet.generated_text.text,
is_assistant=True,
)
)
+1 -1
View File
@@ -32,7 +32,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
print("Assistant: ", end="", flush=True)
for event in client.events():
data = json.loads(event.data)
print(data["token"], end="", flush=True)
print(data["token"]["text"], end="", flush=True)
print()
+31 -6
View File
@@ -54,24 +54,49 @@ def main(
"top_p": work_request.top_p,
"temperature": work_request.temperature,
"seed": work_request.seed,
# "stop": ["\nUser:", "\nAssistant:"], # TODO: make this a bit more workable because it's mutliple tokens
},
},
stream=True,
headers={"Accept": "text/event-stream"},
)
response.raise_for_status()
try:
response.raise_for_status()
except requests.HTTPError:
logger.exception("Failed to get response from inference server")
logger.error(f"Response: {response.text}")
return
client = sseclient.SSEClient(response)
for event in client.events():
logger.debug(f"Received event: {event}")
data = json.loads(event.data)
if data["is_end"]:
if data["generated_text"]:
break
intermediate = data["event"]
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
ws.send(inference.WorkResponsePacket(is_end=True).json())
token = data["token"]
ws.send(
inference.WorkResponsePacket(
token=inference.TokenResponse(
text=token["text"],
log_prob=token["logprob"],
token_id=token["id"],
)
).json()
)
ws.send(
inference.WorkResponsePacket(
is_end=True,
generated_text=inference.GeneratedTextResponse(
text=data["generated_text"],
),
).json()
)
def on_error(ws: websocket.WebSocket, error: Exception):
logger.error(f"Connection error: {error}")
try:
raise error
except Exception:
logger.exception("Error in websocket")
def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
+1 -1
View File
@@ -1,6 +1,6 @@
Some other reward features we can use
0. Finish classifcation feature
0. Finish classification feature
1. Summaries from human feedback
+1 -1
View File
@@ -14,7 +14,7 @@
[] support additional negative samples generated from other models.
For example we can use galactica-125m to generate a TLDR and assume it was
inferior than the human perference one
inferior than the human preference one
"""
@@ -46,6 +46,7 @@ defaults:
quantization: false
seq2seqmodel: false
poly_eps: 1.0
fuse_gelu: true
galactica-125m:
learning_rate: 5e-5
@@ -23,5 +23,5 @@ Issues and TODO:
- ideally we can update the config yaml and the new dataset will be downloaded
from the hub
- one possible idea is we upload the trasform format of these dataset to the
- one possible idea is we upload the transform format of these datasets to the
OA hub
@@ -311,7 +311,7 @@ class JokeExplaination(Dataset):
for line in f:
data = json.loads(line)
joke = data["joke"]
explanation = data["explaination"]
explanation = data["explanation"]
self.pairs.append((joke, explanation))
if len(question) > 0 and len(answer) > 0:
@@ -0,0 +1,55 @@
import functools
import torch
from transformers.activations import FastGELUActivation, GELUActivation, NewGELUActivation, QuickGELUActivation
def rsetattr(obj, attr, val):
    """Set a (possibly dotted) attribute path on ``obj``.

    ``attr`` is split at its last dot: everything before it is resolved with
    ``rgetattr`` to find the owning object, and the final component is assigned
    on that object with ``setattr``. A path without dots is set directly on
    ``obj``.

    Args:
        obj: Root object the path starts from.
        attr: Attribute path such as ``"a.b.c"`` or ``"a"``.
        val: Value to assign.

    Returns:
        ``None`` (the result of ``setattr``).
    """
    pre, _, post = attr.rpartition(".")
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)
def rgetattr(obj, attr, *args):
    """Resolve a (possibly dotted) attribute path starting from ``obj``.

    Args:
        obj: Root object the path starts from.
        attr: Attribute path such as ``"a.b.c"``; each component is looked up
            on the result of the previous one.
        *args: Optional default, forwarded to every ``getattr`` call along the
            path (so a missing intermediate attribute yields the default, and
            the remaining components are then looked up on that default).

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If a component is missing and no default was given.
    """

    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split("."))
def fuse_gelu(model):
    """Swap the Hugging Face GELU activation modules inside ``model`` for a scripted GELU.

    Every submodule that is an instance of ``GELUActivation``,
    ``FastGELUActivation``, ``NewGELUActivation`` or ``QuickGELUActivation`` is
    replaced (in place, via ``rsetattr``) by one shared module whose forward and
    backward passes are TorchScript-compiled functions using the tanh-based
    GELU formula below.

    NOTE(review): all four activation classes are mapped onto the same tanh
    formula; if any of them computes a different function upstream, this swap
    changes the model's numerics. Confirm that is acceptable.

    Args:
        model: A ``torch.nn.Module`` whose activations should be replaced.

    Returns:
        The same ``model`` object, modified in place.
    """

    @torch.jit.script
    def gelu_fwd(x):
        # 0.79788456 ~= sqrt(2 / pi)
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def gelu_bwd(g, x):
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # Analytic derivative of gelu_fwd; 0.1070322243 ~= 3 * 0.044715 * 0.79788456.
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class _FusedGeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            # Use save_for_backward rather than a plain ctx attribute so autograd
            # manages the tensor's lifetime and can detect in-place modification.
            ctx.save_for_backward(input)
            return gelu_fwd(input)

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return gelu_bwd(grad_output, x)

    class FusedGelu(torch.nn.Module):
        def forward(self, input):
            return _FusedGeLUFunction.apply(input)

    # The module is stateless, so a single instance is shared by every replaced activation.
    fused_gelu_module = FusedGelu()
    hf_gelu_classes = (GELUActivation, FastGELUActivation, NewGELUActivation, QuickGELUActivation)

    # Snapshot the module list first: the loop body mutates the module tree.
    for name, module in list(model.named_modules()):
        if not name:
            # The root module has an empty name and cannot be replaced through its
            # parent; setting attribute "" would only register a bogus child.
            continue
        if isinstance(module, hf_gelu_classes):
            rsetattr(model, name, fused_gelu_module)
    return model
+4
View File
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import bitsandbytes
import torch
from efficiency_utils import fuse_gelu
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
@@ -180,6 +181,9 @@ if __name__ == "__main__":
module, "weight", {"optim_bits": 32}
)
if training_conf.fuse_gelu:
model = fuse_gelu(model)
args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
num_train_epochs=training_conf.num_train_epochs,
+58 -34
View File
@@ -1,6 +1,6 @@
import random
from pathlib import Path
from typing import List
from typing import List, NamedTuple
import evaluate
import transformers
@@ -66,19 +66,64 @@ class PerDatasetSampler(Sampler):
return cls(dataset_sizes, dataset_size_per_epoch)
def get_tokenizer(conf):
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
def get_dataset_fractions(conf, dataset_sizes):
    """Calculate fraction of each dataset to use per epoch when subsampling.

    Args:
        conf: Sequence of dataset configs. A non-dict entry means "use the whole
            dataset". A dict entry is keyed by the dataset name and must hold
            either a ``"fraction"`` or a ``"size"`` option.
        dataset_sizes: Number of examples per dataset, indexed like ``conf``.

    Returns:
        One fraction per entry of ``conf``: ``1`` for non-dict entries, the
        configured fraction capped at 1, or ``size / dataset_size``.

    Raises:
        ValueError: If a fraction is <= 0, a size exceeds the dataset size, or a
            dict entry specifies neither option.
    """
    fractions = []
    for i, data_config in enumerate(conf):
        if not isinstance(data_config, dict):
            fractions.append(1)
            continue

        # The dataset name is only needed to reach the options of a dict entry.
        dataset_name = get_dataset_name_from_data_config(data_config)
        options = data_config[dataset_name]
        if "fraction" in options:
            if options["fraction"] <= 0:
                raise ValueError("Please specify fraction as a value between 0 < fraction <= 1")
            # Fractions above 1 are silently capped instead of rejected.
            fractions.append(min(1, options["fraction"]))
        elif "size" in options:
            if options["size"] > dataset_sizes[i]:
                raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
            fractions.append(options["size"] / dataset_sizes[i])
        else:
            raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
    return fractions
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
elif "GPT-JT" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
elif "codegen" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
elif "pythia" in conf.model_name:
tokenizer.add_special_tokens(
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
)
# NOTE(review): NamedTuple instances are immutable, yet get_tokenizer assigns to
# `pad_token` on one for GPT-JT models; that assignment raises AttributeError.
class SpecialTokens(NamedTuple):
    """Special-token strings for a tokenizer; every field defaults to an empty string."""

    pad_token: str = ""
    eos_token: str = ""
    sep_token: str = ""
class TokenizerConfig(NamedTuple):
    """Tokenizer settings for one entry of TOKENIZER_CONFIGS; currently only the special tokens."""

    # NOTE(review): the default is an empty dict, not a SpecialTokens instance.
    # It is kept because `{}` is falsy while `SpecialTokens()` (a non-empty
    # tuple) would be truthy, and get_tokenizer branches on that truthiness.
    special_tokens: SpecialTokens = {}
# Tokenizer configuration per model-name fragment; match_tokenizer_name selects an
# entry when its key is a substring of the model name. Positional SpecialTokens
# arguments are (pad_token, eos_token, sep_token); omitted fields stay "".
TOKENIZER_CONFIGS = {
"galactica": TokenizerConfig(special_tokens=SpecialTokens("<pad>", "</s>")),
"GPT-JT": TokenizerConfig(special_tokens=SpecialTokens(sep_token="<|extratoken_100|>")),
"codegen": TokenizerConfig(special_tokens=SpecialTokens("<|endoftext|>", sep_token="<|endoftext|>")),
"pythia": TokenizerConfig(special_tokens=SpecialTokens("<|padding|>", "<|endoftext|>", "<|endoftext|>")),
}
def match_tokenizer_name(model_name: str) -> TokenizerConfig:
    """Match a partial model name to a tokenizer configuration.

    A configuration matches when its TOKENIZER_CONFIGS key occurs as a
    (case-sensitive) substring of ``model_name``.

    Args:
        model_name: Model name to look up.

    Returns:
        The single matching configuration.

    Raises:
        ValueError: If no key matches, or if more than one key matches.
    """
    tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name]
    if not tokenizer_config_matches:
        raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}")
    if len(tokenizer_config_matches) > 1:
        raise ValueError(f"Found multiple tokeniser configuration matches for {model_name=}")
    return tokenizer_config_matches[0]
def get_tokenizer(conf) -> transformers.AutoTokenizer:
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
tokenizer_config = match_tokenizer_name(conf.model_name)
if tokenizer_config.special_tokens:
if "GPT-JT" in conf.model_name:
tokenizer_config.special_tokens.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens(tokenizer_config.special_tokens)
additional_special_tokens = (
[]
@@ -171,27 +216,6 @@ def get_dataset_name_from_data_config(data_config):
return data_config
def get_dataset_fractions(conf, dataset_sizes):
    """Calculate fraction of each dataset to use per epoch when subsampling.

    Args:
        conf: Sequence of dataset configs. A non-dict entry means "use the whole
            dataset". A dict entry is keyed by the dataset name and must hold
            either a ``"fraction"`` or a ``"size"`` option.
        dataset_sizes: Number of examples per dataset, indexed like ``conf``.

    Returns:
        One fraction per entry of ``conf``: ``1`` for non-dict entries, the
        configured fraction capped at 1, or ``size / dataset_size``.

    Raises:
        ValueError: If a fraction is <= 0, a size exceeds the dataset size, or a
            dict entry specifies neither option.
    """
    fractions = []
    for i, data_config in enumerate(conf):
        if not isinstance(data_config, dict):
            fractions.append(1)
            continue

        # The dataset name is only needed to reach the options of a dict entry.
        dataset_name = get_dataset_name_from_data_config(data_config)
        options = data_config[dataset_name]
        if "fraction" in options:
            if options["fraction"] <= 0:
                raise ValueError("Please specify fraction as a value between 0 < fraction <= 1")
            # Fractions above 1 are silently capped instead of rejected.
            fractions.append(min(1, options["fraction"]))
        elif "size" in options:
            if options["size"] > dataset_sizes[i]:
                raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
            fractions.append(options["size"] / dataset_sizes[i])
        else:
            raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
    return fractions
def get_dataset(conf, tokenizer):
train_datasets, evals = [], {}
@@ -83,7 +83,7 @@ with (open("paragraphs.pkl", "rb")) as openfile:
try:
objects.append(pickle.load(openfile))
except EOFError:
print("Problem laoding your pickle file, using the default array")
print("Problem loading your pickle file, using the default array")
pickle_fail = True
break
@@ -92,7 +92,7 @@ if pickle_fail:
paragraphs = [
"Like for me, this thing is like a little side hobby, but it's also one that's deeply fulfilling. So not just from a business perspective, which is not the way I think about it. I just think from a life human perspective, it's I probably wouldn't have this kind of conversation with you off mic, like this long, this deep, this attentive. There's something really fulfilling about these conversations. So what advice would you have for me? What advice do you have for yourself? Oh, have you not introspected this that deeply? Oh, I have advice. I think the first advice I would give to you is I think you should have me on more often. Yeah. Yeah. That's first and foremost. And second is go on your podcast and have a conversation. Well, I would say you come on my podcast when you're ready. Yeah. When you feel like the product that I'm putting out would benefit from your presence and vice versa, not as a favor to a bro, but at the right time.",
"Well, we really are looking through a two dimensional screen until it's what we intuit to be a three dimensional world and also inferring dynamic stuff, making it 4D. Anyway, is it possible to visualize some pretty pictures that give us a deeper sense of the truth of reality? I think that we will incrementally be able to do that. I think that, for example, the picture that we have of electrons and photons interacting and scattering, it may have not been possible until Faraday did all of his experiments and then Maxwell wrote down his equations. And we were then sort of forced by his equations to think in a new way. And then when Planck in 1900, desperate to try to solve the problem of black body radiation, what they call the ultraviolet catastrophe where Newton was predicting infinite energies where there weren't infinite energies in black body radiation. And he in desperation proposed packets of energy. Then once you've done that, and then you have an Einstein come along five years later and show how that explains the photoelectric effect.",
"But man, I don't know how I would feel about just bacteria everywhere. Well, it would be depressing if it was true. I suppose depressing, I don't think, I don't. I don't know what's more depressing, bacteria everywhere or nothing everywhere. Yes, either of them are chilling. Yeah. But whether it's chilling or not, I don't think should force us to change our view about whether it's real or not. Yes. And what I'm saying may or may not be true. So how would you feel if we discovered life on Mars? Absolutely. It sounds like you would be less excited than some others because you're like, well. What I would be most interested in is how similar to life on Earth it would be. It would actually turn into quite a subtle problem because the likelihood of life having gone to and fro between Mars and the Earth is quite, I wouldn't say high, but it's not low, it's quite feasible. And so if we found life on Mars and it had very similar genetic code, but it was slightly different, most people would interpret that immediately as evidence that they've been transit one way or the other and that it was a common origin of life on Mars or on the Earth and it went one way or the other way.",
"But man, I don't know how I would feel about just bacteria everywhere. Well, it would be depressing if it was true. I suppose depressing, I don't think, I don't. I don't know what's more depressing, bacteria everywhere or nothing everywhere. Yes, either of them are chilling. Yeah. But whether it's chilling or not, I don't think should force us to change our view about whether it's real or not. Yes. And what I'm saying may or may not be true. So how would you feel if we discovered life on Mars? Absolutely. It sounds like you would be less excited than some others because you're like, well. What I would be most interested in is how similar to life on Earth it would be. It would actually turn into quite a subtle problem because the likelihood of life having gone to and from between Mars and the Earth is quite, I wouldn't say high, but it's not low, it's quite feasible. And so if we found life on Mars and it had very similar genetic code, but it was slightly different, most people would interpret that immediately as evidence that they've been transit one way or the other and that it was a common origin of life on Mars or on the Earth and it went one way or the other way.",
]
# Make sure no paragraphs are too long for T5. It handles up to 512 tokens context length.
@@ -86,7 +86,7 @@ Each question and all related answers are on a single line in JSONL format.
#### Table/CSV/Parquet Format
There are a lot more columns left over in the table format. `_q` and `_a` are
suffixes indiciating if the column came from the question or answer table as
suffixes indicating if the column came from the question or answer table as
leftover from a join statement.
```
+1 -1
View File
@@ -15,7 +15,7 @@ trained on
| multilingual | xlm-roberta-base | Multilingual Toxic Comment Classification |
Unbiased and original models also have a 'small' version - but since normal
models are not memory heavy, and small models perform noticably worse, they are
models are not memory heavy, and small models perform noticeably worse, they are
only described in the notebook
## All tests below were ran on a 3090TI
+1 -1
View File
@@ -7,7 +7,7 @@ this project. Please try and follow this structure as closely as possible. While
things will not exactly be the same for each notebook some principles we would
like to try to ensure are:
1. Each notebook or collection of related or dependant notebooks should live in
1. Each notebook or collection of related or dependent notebooks should live in
its own folder.
1. Each notebook should have a markdown file with the same name as the notebook
(or README.md if it's a single notebook folder) that explains what the
+8 -4
View File
@@ -68,12 +68,15 @@ class OasstApiClient:
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key})
logger.debug(f"response: {response}")
# If the response is not a 2XX, check to see
# if the json has the fields to create an
# OasstError.
if response.status >= 300:
text = await response.text()
logger.debug(f"resp text: {text}")
data = await response.json()
try:
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
@@ -114,20 +117,21 @@ class OasstApiClient:
task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None,
collective: bool = False,
lang: Optional[str] = None,
) -> protocol_schema.Task:
"""Fetch a task from the backend."""
logger.debug(f"Fetching task {task_type} for user {user}")
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
resp = await self.post("/api/v1/tasks/", data=req.dict())
logger.debug(f"RESP {resp}")
return self._parse_task(resp)
async def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False
self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
) -> protocol_schema.Task:
"""Fetch a random task from the backend."""
logger.debug(f"Fetching random for user {user}")
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
"""Send an ACK for a task to the backend."""
@@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum):
SERVER_ERROR0 = 500
SERVER_ERROR1 = 501
INVALID_AUTHENTICATION = 600
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
TASK_ACK_FAILED = 1001
@@ -80,6 +82,7 @@ class OasstErrorCode(IntEnum):
USER_NOT_SPECIFIED = 4000
USER_DISABLED = 4001
USER_NOT_FOUND = 4002
USER_HAS_NOT_ACCEPTED_TOS = 4003
EMOJI_OP_UNSUPPORTED = 5000
@@ -92,7 +95,7 @@ class OasstError(Exception):
http_status_code: HTTPStatus
def __init__(self, message: str, error_code: OasstErrorCode, http_status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
super().__init__(message, error_code, http_status_code) # make excetpion picklable (fill args member)
super().__init__(message, error_code, http_status_code) # make exception picklable (fill args member)
self.message = message
self.error_code = error_code
self.http_status_code = http_status_code
+13 -2
View File
@@ -13,13 +13,24 @@ class WorkRequest(pydantic.BaseModel):
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_new_tokens: int = 100
seed: int = pydantic.Field(default_factory=lambda: random.randint(-(2**31), 2**31 - 1))
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**31 - 1))
do_sample: bool = True
top_k: int = 50
top_p: float = 0.9
temperature: float = 1.0
# A single generated token, carried in the `token` field of WorkResponsePacket.
class TokenResponse(pydantic.BaseModel):
# text of the token
text: str
# presumably the log-probability the model assigned to this token -- TODO confirm with the producer
log_prob: float
# presumably the token's id in the model vocabulary -- TODO confirm
token_id: int
# Generated text, carried in the `generated_text` field of WorkResponsePacket.
class GeneratedTextResponse(pydantic.BaseModel):
# presumably the complete generated text (as opposed to a single token) -- TODO confirm
text: str
class WorkResponsePacket(pydantic.BaseModel):
token: str | None = None
token: TokenResponse | None = None
generated_text: GeneratedTextResponse | None = None
is_end: bool = False
+75 -4
View File
@@ -29,6 +29,17 @@ class User(BaseModel):
auth_method: Literal["discord", "local", "system"]
# An account record identified by a provider name and the account id at that provider.
# NOTE(review): presumably links a user to an external auth provider -- confirm against callers.
class Account(BaseModel):
id: UUID
provider: str
provider_account_id: str
# An access token together with its type.
# NOTE(review): looks like an OAuth-style token response (e.g. token_type "bearer") -- confirm.
class Token(BaseModel):
access_token: str
token_type: str
class FrontEndUser(User):
user_id: UUID
enabled: bool
@@ -39,6 +50,7 @@ class FrontEndUser(User):
streak_days: Optional[int] = None
streak_last_day_date: Optional[datetime] = None
last_activity_date: Optional[datetime] = None
tos_acceptance_date: Optional[datetime] = None
class PageResult(BaseModel):
@@ -64,6 +76,7 @@ class ConversationMessage(BaseModel):
is_assistant: bool
emojis: Optional[dict[str, int]] = None
user_emojis: Optional[list[str]] = None
user_is_author: Optional[bool] = None
class Conversation(BaseModel):
@@ -89,6 +102,12 @@ class Message(ConversationMessage):
created_date: Optional[datetime]
review_result: Optional[bool]
review_count: Optional[int]
deleted: Optional[bool]
synthetic: Optional[bool]
model_name: Optional[str]
message_tree_id: Optional[UUID]
ranking_count: Optional[int]
rank: Optional[int]
class MessagePage(PageResult):
@@ -252,22 +271,21 @@ class AbstractLabelTask(Task):
mode: Optional[LabelTaskMode]
disposition: Optional[LabelTaskDisposition]
labels: Optional[list[LabelDescription]]
conversation: Conversation # the conversation so far (labeling -> last message)
class LabelInitialPromptTask(AbstractLabelTask):
"""A task to label an initial prompt."""
type: Literal["label_initial_prompt"] = "label_initial_prompt"
prompt: str
prompt: str | None = Field(None, deprecated=True, description="deprecated, use `prompt_message`")
class LabelConversationReplyTask(AbstractLabelTask):
"""A task to label a reply to a conversation."""
type: Literal["label_conversation_reply"] = "label_conversation_reply"
conversation: Conversation # the conversation so far (new: including the reply message)
reply_message: Optional[ConversationMessage]
reply: str
reply: str | None = Field(None, deprecated=True, description="deprecated, use last message of `conversation`")
class LabelPrompterReplyTask(LabelConversationReplyTask):
@@ -469,6 +487,47 @@ class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]
# One user's entry on the trollboard (see TrollboardStats.trollboard).
# All counters default to 0 and all averaged label values default to None.
class TrollScore(BaseModel):
rank: Optional[int]
user_id: UUID
highlighted: bool = False
username: str
auth_method: str
display_name: str
last_activity_date: Optional[datetime]
troll_score: int = 0
base_date: Optional[datetime]
modified_date: Optional[datetime]
red_flags: int = 0 # num reported messages of user
upvotes: int = 0 # num up-voted messages of user
downvotes: int = 0 # num down-voted messages of user
spam_prompts: int = 0
quality: Optional[float] = None
humor: Optional[float] = None
toxicity: Optional[float] = None
violence: Optional[float] = None
helpfulness: Optional[float] = None
spam: int = 0
# NOTE(review): spelled "lang_mismach", while label dictionaries elsewhere in this
# repository use the key "lang_mismatch"; renaming would change the serialized
# field name, so confirm with API consumers before fixing.
lang_mismach: int = 0
not_appropriate: int = 0
pii: int = 0
hate_speech: int = 0
sexual_content: int = 0
political_content: int = 0
# A trollboard snapshot: a list of TrollScore entries for one time frame.
class TrollboardStats(BaseModel):
# name of the time frame the snapshot covers -- presumably day/week/month/total; TODO confirm
time_frame: str
last_updated: datetime
trollboard: List[TrollScore]
class OasstErrorResponse(BaseModel):
"""The format of an error response from the OASST API."""
@@ -489,6 +548,11 @@ class EmojiCode(str, enum.Enum):
poop = "poop" # 💩
skull = "skull" # 💀
# skip task system uses special emoji codes
skip_reply = "_skip_reply"
skip_ranking = "_skip_ranking"
skip_labeling = "_skip_labeling"
class EmojiOp(str, enum.Enum):
togggle = "toggle"
@@ -500,3 +564,10 @@ class MessageEmojiRequest(BaseModel):
user: User
op: EmojiOp = EmojiOp.togggle
emoji: EmojiCode
# Request body for creating a frontend user: the User fields plus optional settings.
class CreateFrontendUserRequest(User):
show_on_leaderboard: bool = True
enabled: bool = True
# None leaves the terms-of-service state unspecified; the text and automation
# frontends send True so their dummy users count as having accepted the ToS.
tos_acceptance: Optional[bool] = None
notes: Optional[str] = None
+108
View File
@@ -0,0 +1,108 @@
# **Datasets**
This folder contains dataset loading scripts that are used to train
OpenAssistant. The current list of datasets can be found
[here](https://docs.google.com/spreadsheets/d/1NYYa6vHiRnk5kwnyYaCT0cBO62--Tm3w4ihdBtp4ISk).
## **Adding a New Dataset**
To add a new dataset to OpenAssistant, follow these steps:
1. **Create an issue**: Create a new
[issue](https://github.com/LAION-AI/Open-Assistant/issues/new) and describe
your proposal for the new dataset.
2. **Create a dataset on HuggingFace**: Create a dataset on
[HuggingFace](https://huggingface.co). See
[below](#creating-a-dataset-on-huggingface) for more details.
3. **Make a pull request**: Add a new dataset loading script to this folder and
link the issue in the pull request description. For more information, see
[below](#making-a-pull-request).
## **Creating a Dataset on HuggingFace**
To create a new dataset on HuggingFace, follow these steps:
#### 1. Convert your dataset file(s) to the Parquet format using the [pandas](https://pandas.pydata.org/) library:
```python
import pandas as pd
# Create a pandas dataframe from your dataset file(s)
df = pd.read_json(...) # or any other way
# Save the file in the Parquet format
df.to_parquet("dataset.parquet", row_group_size=100, engine="pyarrow")
```
#### 2. Install HuggingFace CLI
```bash
pip install -U huggingface_hub
```
#### 3. Log in to HuggingFace
Use your [access token](https://huggingface.co/docs/hub/security-tokens) to
log in:
- Via terminal
```bash
huggingface-cli login
```
- In a Jupyter notebook
```python
from huggingface_hub import notebook_login
notebook_login()
```
#### 4. Push the Parquet file to HuggingFace using the following code:
```python
from datasets import Dataset
ds = Dataset.from_parquet("dataset.parquet")
ds.push_to_hub("your_huggingface_name/dataset_name")
```
#### 5. Update the `README.md` file
Update the `README.md` file of your dataset by visiting this link:
https://huggingface.co/datasets/your_huggingface_name/dataset_name/edit/main/README.md
(replace `your_huggingface_name` and `dataset_name` with your HuggingFace user
name and the name of your dataset)
## **Making a Pull Request**
#### 1. Fork this repository
#### 2. Create a new branch in your fork
#### 3. Add your dataset to the repository
- Create a folder with the name of your dataset.
- Add a loading script that loads your dataset from HuggingFace, for example:
```python
from datasets import load_dataset
if __name__ == "__main__":
ds = load_dataset("your_huggingface_name/dataset_name")
print(ds["train"])
```
- Optionally, add any other files that describe your dataset and its creation,
such as a README, notebooks, scrapers, etc.
#### 4. Stage your changes and run the pre-commit hook
```bash
pre-commit run
```
#### 5. Submit a pull request
- Submit a pull request and include a link to the issue it resolves in the
description, for example: `Resolves #123`
+18
View File
@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Start the backend with uvicorn on port 8080 in a debug configuration that
# loads seed data from backend/test_data/generic/test_generic_data.json.

# Absolute, symlink-resolved directory containing this script. Abort if it
# cannot be determined, otherwise every path below would be wrong.
parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P ) || exit 1

# switch to backend directory; abort instead of starting uvicorn elsewhere
pushd "$parent_path/../../backend" || exit 1

export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_TOXICITY_CALCULATION=True
export DEBUG_ALLOW_SELF_LABELING=True
export DEBUG_ALLOW_SELF_RANKING=True
export DEBUG_ALLOW_DUPLICATE_TASKS=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
export RATE_LIMIT=0
# Quoted so that a checkout path containing spaces stays a single value.
export DEBUG_USE_SEED_DATA_PATH="$parent_path/../../backend/test_data/generic/test_generic_data.json"

uvicorn main:app --reload --port 8080 --host 0.0.0.0

popd
+1
View File
@@ -7,6 +7,7 @@ pushd "$parent_path/../../backend"
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_TOXICITY_CALCULATION=True
export DEBUG_ALLOW_SELF_LABELING=True
export DEBUG_ALLOW_SELF_RANKING=True
export DEBUG_ALLOW_DUPLICATE_TASKS=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
+2 -2
View File
@@ -57,7 +57,7 @@ conversation, or at least as a prompt with replies.
guarantee of the quality of the tweets.
- The tweet quality is the other major issue. We can get conversations through
the currently made scripts, but they most likely don't match a useful
instruction -> fulfilment. We are trying to filter the tweets through various
instruction -> fulfillment. We are trying to filter the tweets through various
means such as matching useful hashtags, or by using cosine similarity against
known instructions.
- The modern Twitter API has conversation_id as a field which can be a way to
@@ -68,7 +68,7 @@ conversation, or at least as a prompt with replies.
## TODO
- Write scripts to filter existing conversations into useful instructions ->
fulfilment with hashtags or cosine similarity.
fulfillment with hashtags or cosine similarity.
- Train model to detect if text is a suitable instruction. This could then be
run through the conversations (or full tweet dump) to simplify the process.
Related to issue #143.
@@ -9,7 +9,7 @@
# This assumes data downloaded from https://archive.org/details/twitterstream
# and that the internal .tar files are extracted locally.
# They are large files so using something like 7Zip or WinRar migth be easier
# They are large files so using something like 7Zip or WinRar might be easier
# than putting all of it in scripts, but it is a possibility.
# I often work in notebooks. If you encounter any issue, please reach out to let me know.
+2 -2
View File
@@ -94,7 +94,7 @@ class EssayReviser(DataAugmenter):
def parse_single(self, essay):
instructions = []
# Make stucture error (shuffle one paragraph with another)
# Make structure error (shuffle one paragraph with another)
essay_paragraphs = essay.split("\n\n") # Splitting a String by newline character (\n)
rand1 = random.randint(0, len(essay_paragraphs) - 1)
@@ -424,7 +424,7 @@ class CodeInstructor(DataAugmenter):
def recognize_entities(text, model, n=4, person="ignore"):
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
"""Given a text and a model for entity recognition, return the most occurring entities in the text as a string"""
doc = model(text)
if person == "ignore":
ents = Counter([ent.text.strip() for ent in list(doc.ents) if len(ent.text.strip()) >= 5])
@@ -0,0 +1,56 @@
from collections import defaultdict
from glob import glob
from json import load
from os import path
# Glob pattern, joined onto DIR in main(), matching the JSON files of every locale.
ALL_PATH = "../../website/public/locales/**/*.json"
# Directory containing this script; both glob patterns are resolved relative to it.
DIR = path.dirname(__file__)
# Glob pattern, joined onto DIR in main(), matching the English reference files.
EN_PATH = "../../website/public/locales/en/*.json"
def get_not_translated(en_json, translation_json, parent_key=None):
    """Return the keys whose translated value is identical to the English one.

    Nested sections are walked recursively and reported with their full
    dotted path (e.g. ``"section.sub.key"``). A nested section that is equal
    as a whole is reported once under the section's own key.

    Args:
        en_json: Parsed English reference mapping.
        translation_json: Parsed translation mapping to compare against.
        parent_key: Dotted path of the enclosing section, or ``None`` at the
            top level.

    Returns:
        A list of dotted key paths, in the iteration order of ``en_json``.
    """
    not_translated = []
    for key, en_value in en_json.items():
        if key not in translation_json:
            # Nothing to compare; previously a missing nested section raised KeyError.
            continue
        full_key = f"{parent_key}.{key}" if parent_key else key
        translated_value = translation_json[key]
        if translated_value == en_value:
            not_translated.append(full_key)
        elif isinstance(en_value, dict) and isinstance(translated_value, dict):
            # Pass the full path down so deeper levels keep all their ancestors.
            not_translated.extend(get_not_translated(en_value, translated_value, full_key))
    return not_translated
def get_missing(en_json, translation_json):
    """Return the top-level keys of ``en_json`` that ``translation_json`` lacks.

    The result preserves the iteration order of ``en_json``.
    """
    absent = []
    for english_key in en_json:
        if english_key in translation_json:
            continue
        absent.append(english_key)
    return absent
def print_result(missing, not_translated, file):
    """Print the audit findings for one translation file.

    Each non-empty list produces one line, prefixed with the language (the
    name of the file's parent directory) and the file name. Nothing is
    printed for an empty list.
    """

    def _report(label, keys):
        # Resolve the path parts only when there is something to print.
        if len(keys) == 0:
            return
        language = path.basename(path.dirname(file))
        file_name = path.basename(file)
        print(f"[{language}] - {file_name}\t{label}: {keys}")

    _report("missing", missing)
    _report("potentially untranslated", not_translated)
def audit(file, en_file):
    """Compare one translation file against its English counterpart.

    Args:
        file: Path of the translation JSON file.
        en_file: Path of the English reference JSON file.

    Returns:
        A ``(missing, not_translated, file)`` tuple: the keys absent from the
        translation, the keys whose value equals the English value, and the
        translation file path.
    """
    # Context managers close the handles; JSON is read as UTF-8 regardless of
    # the platform's default encoding.
    with open(en_file, encoding="utf-8") as en_handle:
        en_json = load(en_handle)
    with open(file, encoding="utf-8") as translation_handle:
        translation_json = load(translation_handle)
    return (get_missing(en_json, translation_json), get_not_translated(en_json, translation_json), file)
def main():
    """Audit every locale file against the English file of the same name.

    Findings are grouped by language (the name of the file's parent
    directory) and printed group by group, with a blank line after each group.
    """
    # The set of candidate files does not depend on the English file, so list it once.
    all_files = glob(path.join(DIR, ALL_PATH))

    per_language_dict = defaultdict(list)
    for en_file in glob(path.join(DIR, EN_PATH)):
        en_name = path.basename(en_file)
        for file in all_files:
            # Only compare same-named files, and never the English file with itself.
            if file == en_file or path.basename(file) != en_name:
                continue
            lang = path.basename(path.dirname(file))
            per_language_dict[lang].append(audit(file, en_file))

    for results in per_language_dict.values():
        for result in results:
            print_result(*result)
        print()
if __name__ == "__main__":
main()
+1 -1
View File
@@ -66,7 +66,7 @@ def get_winner(pairs):
def get_ranking(pairs):
"""
Abuses concordance property to get a (not necessarily unqiue) ranking.
Abuses concordance property to get a (not necessarily unique) ranking.
The lack of uniqueness is due to the potential existence of multiple
equally ranked winners. We have to pick one, which is where
the non-uniqueness comes from
+3 -3
View File
@@ -58,7 +58,7 @@ def score_update_votes(new_vote: int, consensus: npt.ArrayLike, voter_data: Vote
after that voter cast a vote on a question.
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information
The consensus is the array of all votes cast by all voters for that question
We then update the voter data using the new information
@@ -88,7 +88,7 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
This function returns the gain of points for a given prompt's votes
In contrast to the other score updating functions, we can run this online as new votes come in.
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information.
i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information.
Parameters:
@@ -122,7 +122,7 @@ def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.Arr
This function returns the gain of points for a given ranking's votes
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
i.e. the question has had sufficiently many votes, or we can't get more than "K" bits of information
we use the bubble-sort distance (or "kendall-tau" distance) to compare the two rankings
we use this over spearman correlation since:
+1 -1
View File
@@ -56,7 +56,7 @@ def next_answer_task(possible_prompts, answers_per_prompt):
This helps to not have too much close-to-finished prompts in the active set.
Parameters:
possible_prompts (dict[prompt_id, num_answers]): a dictonary containing all open prompts and the number of answers these prompts currently have.
possible_prompts (dict[prompt_id, num_answers]): a dictionary containing all open prompts and the number of answers these prompts currently have.
answers_per_prompt (int): number of answers we per prompt to target
Returns:
prompt_id (int): the prompt_id corresponding to the next prompt that should get a new answer
+10
View File
@@ -28,6 +28,16 @@ def _render_message(message: dict) -> str:
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
"""Simple REPL frontend."""
# make sure dummy user has accepted the terms of service
create_user_request = dict(USER)
create_user_request["tos_acceptance"] = True
response = requests.post(
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
)
response.raise_for_status()
user = response.json()
typer.echo(f"user: {user}")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
+204 -189
View File
@@ -6,12 +6,10 @@ from uuid import uuid4
import requests
import typer
from faker import Faker
app = typer.Typer()
# debug constants
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
fake = Faker()
def _random_message_id():
@@ -26,7 +24,9 @@ def _render_message(message: dict) -> str:
@app.command()
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
def main(
backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234", random_users: int = 1, task_per_user: int = 10
):
"""automates tasks"""
def _post(path: str, json: dict) -> dict:
@@ -50,204 +50,219 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
print(shuffled)
return ranks
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
q = 0
while tasks:
task = tasks.pop(0)
print(task)
for i in range(int(random_users)):
name = fake.name()
USER = {"id": name, "display_name": name, "auth_method": "local"}
match (task["type"]):
case "initial_prompt":
typer.echo("Please provide an initial prompt to the assistant.")
if task["hint"]:
typer.echo(f"Hint: {task['hint']}")
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
create_user_request = dict(USER)
# make sure dummy user has accepted the terms of service
create_user_request["tos_acceptance"] = True
response = requests.post(
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
)
response.raise_for_status()
user = response.json()
typer.echo(f"user: {user}")
q = 0
prompt = gen_random_text()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": prompt,
"user": USER,
},
)
tasks.append(new_task)
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
case "label_initial_prompt":
typer.echo("Label the following prompt:")
typer.echo(task["prompt"])
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
while tasks:
task = tasks.pop(0)
print(task)
valid_labels = task["valid_labels"]
mandatory_labels = task["mandatory_labels"]
match (task["type"]):
case "initial_prompt":
typer.echo("Please provide an initial prompt to the assistant.")
if task["hint"]:
typer.echo(f"Hint: {task['hint']}")
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
labels_dict = None
if task["mode"] == "simple" and len(valid_labels) == 1:
answer = random.choice([True, False])
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
for l in mandatory_labels:
if l not in labels:
labels.append(l)
labels_dict = {label: random.random() for label in valid_labels}
if random.random() < 0.9:
labels_dict["spam"] = 0
labels_dict["lang_mismatch"] = 0
prompt = gen_random_text()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": prompt,
"user": USER,
},
)
tasks.append(new_task)
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["prompt"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "prompter_reply":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": gen_random_text(),
"user": USER,
},
)
tasks.append(new_task)
case "label_initial_prompt":
typer.echo("Label the following prompt:")
typer.echo(task["prompt"])
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
case "assistant_reply":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": gen_random_text(),
"user": USER,
},
)
tasks.append(new_task)
valid_labels = task["valid_labels"]
mandatory_labels = task["mandatory_labels"]
case "rank_prompter_replies" | "rank_assistant_replies":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
ranking = gen_random_ranking(task["replies"])
print(ranking)
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "message_ranking",
"message_id": message_id,
"task_id": task["id"],
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
labels_dict = None
if task["mode"] == "simple" and len(valid_labels) == 1:
answer = random.choice([True, False])
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
for l in mandatory_labels:
if l not in labels:
labels.append(l)
labels_dict = {label: random.random() for label in valid_labels}
if random.random() < 0.9:
labels_dict["spam"] = 0
labels_dict["lang_mismatch"] = 0
case "rank_initial_prompts":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
ranking = gen_random_ranking(task["prompots"])
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "message_ranking",
"message_id": message_id,
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["prompt"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "prompter_reply":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": gen_random_text(),
"user": USER,
},
)
tasks.append(new_task)
case "label_prompter_reply" | "label_assistant_reply":
# acknowledge task
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))
case "assistant_reply":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": gen_random_text(),
"user": USER,
},
)
tasks.append(new_task)
typer.echo("Label the following reply:")
typer.echo(task["reply"])
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
mandatory_labels = task["mandatory_labels"]
case "rank_prompter_replies" | "rank_assistant_replies":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
ranking = gen_random_ranking(task["replies"])
print(ranking)
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "message_ranking",
"message_id": message_id,
"task_id": task["id"],
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
labels_dict = None
if task["mode"] == "simple" and len(valid_labels) == 1:
answer = random.choice([True, False])
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
for l in mandatory_labels:
if l not in labels:
labels.append(l)
labels_dict = {label: random.random() for label in valid_labels}
if random.random() < 0.9:
labels_dict["spam"] = 0
labels_dict["lang_mismatch"] = 0
case "rank_initial_prompts":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
ranking = gen_random_ranking(task["prompots"])
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "message_ranking",
"message_id": message_id,
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["reply"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "task_done":
typer.echo("Task done!")
# rerun with new task slected from above cases
# add a new task
q += 1
if q == 10:
case "label_prompter_reply" | "label_assistant_reply":
# acknowledge task
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))
typer.echo("Label the following reply:")
typer.echo(task["reply"])
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
mandatory_labels = task["mandatory_labels"]
labels_dict = None
if task["mode"] == "simple" and len(valid_labels) == 1:
answer = random.choice([True, False])
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
for l in mandatory_labels:
if l not in labels:
labels.append(l)
labels_dict = {label: random.random() for label in valid_labels}
if random.random() < 0.9:
labels_dict["spam"] = 0
labels_dict["lang_mismatch"] = 0
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["reply"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "task_done":
typer.echo("Task done!")
break
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
#
case _:
typer.echo(f"Unknown task type {task['type']}")
# rerun with new task slected from above cases
# rerun with new task selected from above cases
# add a new task
q += 1
if q == task_per_user:
typer.echo("Task done!")
break
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
#
case _:
typer.echo(f"Unknown task type {task['type']}")
# rerun with new task selected from above cases
if __name__ == "__main__":
+1
View File
@@ -1,2 +1,3 @@
faker==16.6.1
requests==2.28.1
typer==0.7.0

Some files were not shown because too many files have changed in this diff Show More