diff --git a/.github/workflows/deploy-to-node.yaml b/.github/workflows/deploy-to-node.yaml index 55736b42..182c357d 100644 --- a/.github/workflows/deploy-to-node.yaml +++ b/.github/workflows/deploy-to-node.yaml @@ -33,14 +33,6 @@ jobs: WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }} WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }} WEB_NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }} - WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY: - ${{ secrets.DEV_WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }} - WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY: - ${{ secrets.DEV_WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY }} - WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA: - ${{ secrets.DEV_WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }} - WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN: - ${{ secrets.DEV_WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }} S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} S3_REGION: ${{ secrets.S3_REGION }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} @@ -49,11 +41,20 @@ jobs: MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }} MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }} GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }} + MESSAGE_SIZE_LIMIT: ${{ vars.MESSAGE_SIZE_LIMIT }} SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }} STATS_INTERVAL_DAY: ${{ vars.STATS_INTERVAL_DAY }} STATS_INTERVAL_WEEK: ${{ vars.STATS_INTERVAL_WEEK }} STATS_INTERVAL_MONTH: ${{ vars.STATS_INTERVAL_MONTH }} STATS_INTERVAL_TOTAL: ${{ vars.STATS_INTERVAL_TOTAL }} + WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY: + ${{ secrets.WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }} + WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY: + ${{ secrets.WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY }} + WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA: + ${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }} + WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN: + ${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }} steps: - name: Checkout uses: actions/checkout@v2 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 39773b41..c0d77f69 100644 --- a/CONTRIBUTING.md +++ 
b/CONTRIBUTING.md @@ -140,4 +140,4 @@ automatically deploy the built release to the dev machine. ### Contribute a Dataset See -[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md) +[here](https://github.com/LAION-AI/Open-Assistant/blob/main/openassistant/datasets/README.md) diff --git a/ansible/deploy-to-node.yaml b/ansible/deploy-to-node.yaml index 3238228e..81428988 100644 --- a/ansible/deploy-to-node.yaml +++ b/ansible/deploy-to-node.yaml @@ -122,6 +122,9 @@ TREE_MANAGER__MAX_CHILDREN_COUNT: "{{ lookup('ansible.builtin.env', 'MAX_CHILDREN_COUNT') | default('3', true) }}" + MESSAGE_SIZE_LIMIT: + "{{ lookup('ansible.builtin.env', 'MESSAGE_SIZE_LIMIT') | + default('2000', true) }}" USER_STATS_INTERVAL_DAY: "{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_DAY') | default('5', true) }}" @@ -175,9 +178,9 @@ NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY: "{{ lookup('ansible.builtin.env', 'WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY') }}" - CLOUDFLARE_CAPTCHA_SERCERT_KEY: + CLOUDFLARE_CAPTCHA_SECRET_KEY: "{{ lookup('ansible.builtin.env', - 'WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY') }}" + 'WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY') }}" NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA: "{{ lookup('ansible.builtin.env', 'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA') }}" diff --git a/backend/alembic/versions/2023_01_29_1207-7b8f0011e0b0_move_user_streak_from_user_stats_to_.py b/backend/alembic/versions/2023_01_29_1207-7b8f0011e0b0_move_user_streak_from_user_stats_to_.py index 9cf3a233..9117ad52 100644 --- a/backend/alembic/versions/2023_01_29_1207-7b8f0011e0b0_move_user_streak_from_user_stats_to_.py +++ b/backend/alembic/versions/2023_01_29_1207-7b8f0011e0b0_move_user_streak_from_user_stats_to_.py @@ -36,6 +36,7 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### - op.drop_column("user_stats", "streak_days") - op.drop_column("user_stats", "streak_last_day_date") + op.drop_column("user", "streak_days") + op.drop_column("user", "streak_last_day_date") + op.drop_column("user", "last_activity_date") # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py b/backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py new file mode 100644 index 00000000..bca17b4f --- /dev/null +++ b/backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py @@ -0,0 +1,34 @@ +"""add tos_acceptance_date to user + +Revision ID: 55361f323d12 +Revises: f60958968ff8 +Create Date: 2023-02-01 00:22:08.280251 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "55361f323d12" +down_revision = "f60958968ff8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True)) + op.drop_column("user_stats", "streak_days") + op.drop_column("user_stats", "streak_last_day_date") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column( + "user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True) + ) + op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True)) + op.drop_column("user", "tos_acceptance_date") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_02_01_1010-f60958968ff8_add_won_prompt_lottery_date_to_mts.py b/backend/alembic/versions/2023_02_01_1010-f60958968ff8_add_won_prompt_lottery_date_to_mts.py new file mode 100644 index 00000000..f82f28fd --- /dev/null +++ b/backend/alembic/versions/2023_02_01_1010-f60958968ff8_add_won_prompt_lottery_date_to_mts.py @@ -0,0 +1,27 @@ +"""add won_prompt_lottery_date to mts + +Revision ID: f60958968ff8 +Revises: 7b8f0011e0b0 +Create Date: 2023-02-01 10:10:38.301707 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "f60958968ff8" +down_revision = "7b8f0011e0b0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("message_tree_state", sa.Column("won_prompt_lottery_date", sa.DateTime(timezone=True), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("message_tree_state", "won_prompt_lottery_date") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_02_01_2146-9e7ec4a9e3f2_add_skip_bool_skip_reason_to_task.py b/backend/alembic/versions/2023_02_01_2146-9e7ec4a9e3f2_add_skip_bool_skip_reason_to_task.py new file mode 100644 index 00000000..92c36adc --- /dev/null +++ b/backend/alembic/versions/2023_02_01_2146-9e7ec4a9e3f2_add_skip_bool_skip_reason_to_task.py @@ -0,0 +1,30 @@ +"""add skip bool & skip_reason to task + +Revision ID: 9e7ec4a9e3f2 +Revises: 55361f323d12 +Create Date: 2023-02-01 21:46:49.971052 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9e7ec4a9e3f2" +down_revision = "55361f323d12" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("task", sa.Column("skipped", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + op.add_column("task", sa.Column("skip_reason", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("task", "skip_reason") + op.drop_column("task", "skipped") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_02_02_1544-4d7e0b0ebe84_add_troll_stats.py b/backend/alembic/versions/2023_02_02_1544-4d7e0b0ebe84_add_troll_stats.py new file mode 100644 index 00000000..aa9b1ffe --- /dev/null +++ b/backend/alembic/versions/2023_02_02_1544-4d7e0b0ebe84_add_troll_stats.py @@ -0,0 +1,59 @@ +"""add troll_stats + +Revision ID: 4d7e0b0ebe84 +Revises: 9e7ec4a9e3f2 +Create Date: 2023-02-02 15:44:12.647260 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "4d7e0b0ebe84" +down_revision = "9e7ec4a9e3f2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "troll_stats", + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("base_date", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False + ), + sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("troll_score", sa.Integer(), nullable=False), + sa.Column("rank", sa.Integer(), nullable=True), + sa.Column("red_flags", sa.Integer(), nullable=False), + sa.Column("upvotes", sa.Integer(), nullable=False), + sa.Column("downvotes", sa.Integer(), nullable=False), + sa.Column("spam_prompts", sa.Integer(), nullable=False), + sa.Column("quality", sa.Float(), nullable=True), + sa.Column("humor", sa.Float(), nullable=True), + sa.Column("toxicity", sa.Float(), nullable=True), + sa.Column("violence", sa.Float(), nullable=True), + sa.Column("helpfulness", sa.Float(), nullable=True), + sa.Column("spam", sa.Integer(), nullable=False), + sa.Column("lang_mismach", sa.Integer(), nullable=False), + sa.Column("not_appropriate", sa.Integer(), nullable=False), + sa.Column("pii", sa.Integer(), nullable=False), + sa.Column("hate_speech", sa.Integer(), nullable=False), + sa.Column("sexual_content", sa.Integer(), nullable=False), + sa.Column("political_content", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("user_id", "time_frame"), + ) + op.create_index("ix_troll_stats__timeframe__user_id", "troll_stats", ["time_frame", "user_id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("ix_troll_stats__timeframe__user_id", table_name="troll_stats") + op.drop_table("troll_stats") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py b/backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py new file mode 100644 index 00000000..3ab59b08 --- /dev/null +++ b/backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py @@ -0,0 +1,39 @@ +"""Add Account table + +Revision ID: 8c8241d1f973 +Revises: 4d7e0b0ebe84 +Create Date: 2023-01-30 15:10:58.776315 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "8c8241d1f973" +down_revision = "4d7e0b0ebe84" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "account", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["user.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("provider", "account", ["provider_account_id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("provider", table_name="account") + op.drop_table("account") + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index 07d0b45b..8e30b78e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -147,6 +147,7 @@ if settings.DEBUG_USE_SEED_DATA: ur = UserRepository(db=session, api_client=api_client) tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur) + ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True) pr = PromptRepository( db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr ) diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index 003f039f..331a7841 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -10,6 +10,7 @@ from oasst_backend.api.v1 import ( stats, tasks, text_labels, + trollboards, users, ) @@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"]) api_router.include_router(stats.router, prefix="/stats", tags=["stats"]) api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"]) +api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"]) api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"]) api_router.include_router(admin.router, prefix="/admin", tags=["admin"]) api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index a4ca6380..114a3a9c 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -59,6 +59,37 @@ def query_frontend_user( return user.to_protocol_frontend_user() +@router.post("/", 
response_model=protocol.FrontEndUser)
+def create_frontend_user(
+    *,
+    create_user: protocol.CreateFrontendUserRequest,
+    api_client: ApiClient = Depends(deps.get_api_client),
+    db: Session = Depends(deps.get_db),
+):
+    ur = UserRepository(db, api_client)
+    user = ur.lookup_client_user(create_user, create_missing=True)
+
+    def changed(a, b) -> bool:
+        return a is not None and a != b
+
+    # only call update_user if something changed
+    if (
+        changed(create_user.enabled, user.enabled)
+        or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
+        or changed(create_user.notes, user.notes)
+        or (create_user.tos_acceptance and user.tos_acceptance_date is None)
+    ):
+        user = ur.update_user(
+            user.id,
+            enabled=create_user.enabled,
+            show_on_leaderboard=create_user.show_on_leaderboard,
+            tos_acceptance=create_user.tos_acceptance,
+            notes=create_user.notes,
+        )
+
+    return user.to_protocol_frontend_user()
+
+
 @router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
 def query_frontend_user_messages(
     auth_method: str,
diff --git a/backend/oasst_backend/api/v1/login.py b/backend/oasst_backend/api/v1/login.py
new file mode 100644
index 00000000..8aab5328
--- /dev/null
+++ b/backend/oasst_backend/api/v1/login.py
@@ -0,0 +1,87 @@
+import aiohttp
+from fastapi import APIRouter, Depends, HTTPException, Request
+from oasst_backend import auth
+from oasst_backend.api import deps
+from oasst_backend.config import settings
+from oasst_backend.models.user import Account
+from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
+from oasst_shared.schemas import protocol as protocol_schema
+from sqlmodel import Session
+from starlette.status import HTTP_401_UNAUTHORIZED
+
+router = APIRouter()
+
+
+@router.get("/discord")
+def login_discord(request: Request):
+    """Redirect the caller (HTTP 302) to Discord's OAuth2 authorization page."""
+    redirect_uri = f"{get_callback_uri(request)}/discord"
+    # Read values from the `settings` instance, not from the `Settings` class.
+    auth_url = f"https://discord.com/api/oauth2/authorize?client_id={settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
+    raise HTTPException(status_code=302, headers={"location": auth_url})
+
+
+@router.get("/callback/discord", response_model=protocol_schema.Token)
+async def callback_discord(
+    auth_code: str,
+    request: Request,
+    db: Session = Depends(deps.get_db),
+):
+    """
+    Exchange a Discord authorization code for an Open-Assistant access token.
+
+    Posts `auth_code` to Discord's token endpoint, reads the Discord user id from
+    `/users/@me` and looks up the linked `Account`. Raises `OasstError`
+    (INVALID_AUTHENTICATION, HTTP 401) when no account is linked; otherwise
+    returns a `protocol_schema.Token` with token type "bearer".
+    """
+    # NOTE(review): Discord's redirect is expected to carry the authorization code
+    # as `code`; confirm how it gets mapped to the `auth_code` query parameter.
+    redirect_uri = f"{get_callback_uri(request)}/discord"
+
+    async with aiohttp.ClientSession(raise_for_status=True) as session:
+        # Exchange the auth code for a Discord access token
+        async with session.post(
+            "https://discord.com/api/oauth2/token",
+            data={
+                "client_id": settings.AUTH_DISCORD_CLIENT_ID,
+                "client_secret": settings.AUTH_DISCORD_CLIENT_SECRET,
+                "grant_type": "authorization_code",
+                "code": auth_code,
+                "redirect_uri": redirect_uri,
+                "scope": "identify",
+            },
+        ) as token_response:
+            token_response_json = await token_response.json()
+            access_token = token_response_json["access_token"]
+
+        # Retrieve user's Discord information using access token
+        async with session.get(
+            "https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
+        ) as user_response:
+            user_response_json = await user_response.json()
+            discord_id = user_response_json["id"]
+
+    account: Account | None = auth.get_account_from_discord_id(db, discord_id)
+
+    if not account:
+        # Discord account is not linked to an OA account
+        raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
+
+    # Discord account is valid and linked to an OA account -> create JWT
+    # create_access_token() expects a plain dict of claims, so pass the user id as a
+    # string. NOTE(review): confirm the claim name with whatever decodes this token.
+    access_token = auth.create_access_token({"user_id": str(account.user_id)})
+
+    return protocol_schema.Token(access_token=access_token, token_type="bearer")
+
+
+def get_callback_uri(request: Request):
+    """
+    Gets the URI for the base callback endpoint with no provider name appended.
+ """ + # This seems ugly, not sure if there is a better way + current_url = str(request.url) + domain = current_url.split("/api/v1/")[0] + redirect_uri = f"{domain}/api/v1/callback" + return redirect_uri diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 7db42662..c817ba67 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -131,7 +131,7 @@ def tasks_acknowledge_failure( logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.") api_client = deps.api_auth(api_key, db) pr = PromptRepository(db, api_client, frontend_user=frontend_user) - pr.task_repository.acknowledge_task_failure(task_id) + pr.skip_task(task_id=task_id, reason=nack_request.reason) except (KeyError, RuntimeError): logger.exception("Failed to not acknowledge task.") raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED) diff --git a/backend/oasst_backend/api/v1/trollboards.py b/backend/oasst_backend/api/v1/trollboards.py new file mode 100644 index 00000000..4ba5c256 --- /dev/null +++ b/backend/oasst_backend/api/v1/trollboards.py @@ -0,0 +1,21 @@ +from typing import Optional + +from fastapi import APIRouter, Depends, Query +from oasst_backend.api import deps +from oasst_backend.models import ApiClient +from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame +from oasst_shared.schemas.protocol import TrollboardStats +from sqlmodel import Session + +router = APIRouter() + + +@router.get("/{time_frame}", response_model=TrollboardStats) +def get_trollboard( + time_frame: UserStatsTimeFrame, + max_count: Optional[int] = Query(100, gt=0, le=10000), + api_client: ApiClient = Depends(deps.get_trusted_api_client), + db: Session = Depends(deps.get_db), +) -> TrollboardStats: + usr = UserStatsRepository(db) + return usr.get_trollboard(time_frame, limit=max_count) diff --git a/backend/oasst_backend/api/v1/users.py 
b/backend/oasst_backend/api/v1/users.py
index 2ced40c1..b3604c3f 100644
--- a/backend/oasst_backend/api/v1/users.py
+++ b/backend/oasst_backend/api/v1/users.py
@@ -191,6 +191,7 @@ def update_user(
     enabled: Optional[bool] = None,
     notes: Optional[str] = None,
     show_on_leaderboard: Optional[bool] = None,
+    tos_acceptance: Optional[bool] = None,
     db: Session = Depends(deps.get_db),
     api_client: ApiClient = Depends(deps.get_trusted_api_client),
 ):
@@ -198,7 +199,7 @@ def update_user(
     Update a user by global user ID. Only trusted clients can update users.
     """
     ur = UserRepository(db, api_client)
-    ur.update_user(user_id, enabled, notes, show_on_leaderboard)
+    ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
 
 
 @router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
diff --git a/backend/oasst_backend/auth.py b/backend/oasst_backend/auth.py
new file mode 100644
index 00000000..2c633fa4
--- /dev/null
+++ b/backend/oasst_backend/auth.py
@@ -0,0 +1,42 @@
+from datetime import datetime, timedelta
+from typing import Optional
+
+from jose import jwt
+from oasst_backend.config import settings
+from oasst_backend.models.user import Account
+from sqlmodel import Session
+
+
+def create_access_token(data: dict) -> str:
+    """
+    Create an encoded JSON Web Token (JWT) using the given data.
+
+    A copy of `data` is extended with an `exp` claim that lies
+    `settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES` minutes after the current UTC time,
+    then signed with `settings.AUTH_SECRET` using `settings.AUTH_ALGORITHM`.
+    The dict passed in is left unchanged.
+    """
+
+    expires_delta = timedelta(minutes=settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
+    to_encode = data.copy()
+    expire = datetime.utcnow() + expires_delta
+    to_encode.update({"exp": expire})
+    encoded_jwt = jwt.encode(to_encode, settings.AUTH_SECRET, algorithm=settings.AUTH_ALGORITHM)
+    return encoded_jwt
+
+
+def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
+    """
+    Get the Open-Assistant Account associated with the given Discord ID.
+ """ + + account: Account = ( + db.query(Account) + .filter( + Account.provider == "discord", + Account.provider_account_id == discord_id, + ) + .first() + ) + + return account diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 449543d5..851f9ace 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -13,7 +13,7 @@ class TreeManagerConfiguration(BaseModel): No new initial prompt tasks are handed out to users if this number is reached.""" - max_tree_depth: int = 6 + max_tree_depth: int = 3 """Maximum depth of message tree.""" max_children_count: int = 3 @@ -22,7 +22,7 @@ class TreeManagerConfiguration(BaseModel): num_prompter_replies: int = 1 """Number of prompter replies to collect per assistant reply.""" - goal_tree_size: int = 15 + goal_tree_size: int = 12 """Total number of messages to gather per tree.""" num_reviews_initial_prompt: int = 3 @@ -135,6 +135,11 @@ class Settings(BaseSettings): AUTH_LENGTH: int = 32 AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=" AUTH_COOKIE_NAME: str = "next-auth.session-token" + AUTH_ALGORITHM: str = "HS256" + AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + + AUTH_DISCORD_CLIENT_ID: str = "" + AUTH_DISCORD_CLIENT_SECRET: str = "" POSTGRES_HOST: str = "localhost" POSTGRES_PORT: str = "5432" @@ -158,6 +163,9 @@ class Settings(BaseSettings): DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False DEBUG_SKIP_TOXICITY_CALCULATION: bool = False DEBUG_DATABASE_ECHO: bool = False + DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS + True # TODO: set False after ToS acceptance UI was added to web-frontend + ) DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120 diff --git a/backend/oasst_backend/models/__init__.py b/backend/oasst_backend/models/__init__.py index 420c0ccd..65594dde 100644 --- a/backend/oasst_backend/models/__init__.py +++ b/backend/oasst_backend/models/__init__.py @@ -8,6 +8,7 @@ from .message_toxicity import MessageToxicity 
from .message_tree_state import MessageTreeState from .task import Task from .text_labels import TextLabels +from .troll_stats import TrollStats from .user import User from .user_stats import UserStats, UserStatsTimeFrame @@ -26,4 +27,5 @@ __all__ = [ "Journal", "JournalIntegration", "MessageEmoji", + "TrollStats", ] diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index a286d483..199f475b 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -1,4 +1,6 @@ +from datetime import datetime from enum import Enum +from typing import Optional from uuid import UUID import sqlalchemy as sa @@ -46,6 +48,9 @@ class State(str, Enum): BACKLOG_RANKING = "backlog_ranking" """Imported tree ready to be activated and ranked by users (currently inactive).""" + PROMPT_LOTTERY_WAITING = "prompt_lottery_waiting" + """Initial prompt has passed spam check, waiting to be drawn to grow.""" + VALID_STATES = ( State.INITIAL_PROMPT_REVIEW, @@ -63,6 +68,7 @@ TERMINAL_STATES = ( State.SCORING_FAILED, State.HALTED_BY_MODERATOR, State.BACKLOG_RANKING, + State.PROMPT_LOTTERY_WAITING, ) @@ -78,3 +84,4 @@ class MessageTreeState(SQLModel, table=True): state: str = Field(nullable=False, max_length=128, index=True) active: bool = Field(nullable=False, index=True) origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True)) + won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py index 7f91b157..ad41c78a 100644 --- a/backend/oasst_backend/models/task.py +++ b/backend/oasst_backend/models/task.py @@ -31,6 +31,8 @@ class Task(SQLModel, table=True): api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") ack: Optional[bool] = None done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, 
server_default=false())) + skipped: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) + skip_reason: str = Field(nullable=True, max_length=512) frontend_message_id: Optional[str] = None message_tree_id: Optional[UUID] = None parent_message_id: Optional[UUID] = None diff --git a/backend/oasst_backend/models/troll_stats.py b/backend/oasst_backend/models/troll_stats.py new file mode 100644 index 00000000..2cef7246 --- /dev/null +++ b/backend/oasst_backend/models/troll_stats.py @@ -0,0 +1,59 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from sqlmodel import Field, Index, SQLModel + + +class TrollStats(SQLModel, table=True): + __tablename__ = "troll_stats" + __table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),) + + time_frame: Optional[str] = Field(nullable=False, primary_key=True) + user_id: Optional[UUID] = Field( + sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True) + ) + base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) + + troll_score: int = 0 + modified_date: Optional[datetime] = Field( + sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()) + ) + + rank: int = Field(nullable=True) + + red_flags: int = 0 # num reported messages of user + upvotes: int = 0 # num up-voted messages of user + downvotes: int = 0 # num down-voted messages of user + + spam_prompts: int = 0 + + quality: float = Field(nullable=True) + humor: float = Field(nullable=True) + toxicity: float = Field(nullable=True) + violence: float = Field(nullable=True) + helpfulness: float = Field(nullable=True) + + spam: int = 0 + lang_mismach: int = 0 + not_appropriate: int = 0 + pii: int = 0 + hate_speech: int = 0 + sexual_content: int = 0 + 
political_content: int = 0 + + def compute_troll_score(self) -> int: + return ( + self.red_flags * 3 + - self.upvotes + + self.downvotes + + self.spam_prompts + + self.lang_mismach + + self.not_appropriate + + self.pii + + self.hate_speech + + self.sexual_content + + self.political_content + ) diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 3d3bd6a9..59239ccf 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -41,6 +41,9 @@ class User(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp()) ) + # terms of service acceptance date + tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) + def to_protocol_frontend_user(self): return protocol.FrontEndUser( user_id=self.id, @@ -55,4 +58,19 @@ class User(SQLModel, table=True): streak_days=self.streak_days, streak_last_day_date=self.streak_last_day_date, last_activity_date=self.last_activity_date, + tos_acceptance_date=self.tos_acceptance_date, ) + + +class Account(SQLModel, table=True): + __tablename__ = "account" + __table_args__ = (Index("provider", "provider_account_id", unique=True),) + + id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") + ), + ) + user_id: UUID = Field(foreign_key="user.id") + provider: str = Field(nullable=False, max_length=128, default="email") # discord or email + provider_account_id: str = Field(nullable=False, max_length=128) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index e889e73b..dacd5f9b 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -35,7 +35,21 @@ from oasst_shared.utils import unaware_to_utc, utcnow from sqlalchemy.orm import Query from 
sqlalchemy.orm.attributes import flag_modified from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update -from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND + +_task_type_and_reaction = ( + ( + (db_payload.PrompterReplyPayload, db_payload.AssistantReplyPayload), + protocol_schema.EmojiCode.skip_reply, + ), + ( + (db_payload.LabelInitialPromptPayload, db_payload.LabelConversationReplyPayload), + protocol_schema.EmojiCode.skip_labeling, + ), + ( + (db_payload.RankInitialPromptsPayload, db_payload.RankConversationRepliesPayload), + protocol_schema.EmojiCode.skip_ranking, + ), +) class PromptRepository: @@ -77,7 +91,14 @@ class PromptRepository: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) if self.user.deleted or not self.user.enabled: - raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED) + raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE) + + if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE: + raise OasstError( + "User has not accepted terms of service.", + OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS, + HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS, + ) def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message: validate_frontend_message_id(frontend_message_id) @@ -90,7 +111,7 @@ class PromptRepository: raise OasstError( f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND, - HTTP_404_NOT_FOUND, + HTTPStatus.NOT_FOUND, ) return message @@ -139,7 +160,12 @@ class PromptRepository: return message def _validate_task( - self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None + self, + task: Task, + *, + task_id: Optional[UUID] = None, + frontend_message_id: Optional[str] = None, + check_ack: bool = True, ) -> Task: if task is None: if task_id: @@ -150,7 +176,7 
@@ class PromptRepository: if task.expired: raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if not task.ack: + if check_ack and not task.ack: raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK) if task.done: raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE) @@ -675,7 +701,7 @@ class PromptRepository: message = self.db.query(Message).filter(Message.id == message_id).one_or_none() if fail_if_missing and not message: - raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) + raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND) return message def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]: @@ -874,7 +900,7 @@ class PromptRepository: if api_client_id != self.api_client.id: # Unprivileged api client asks for foreign messages - raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN) qry = self.db.query(Message) if user_id: @@ -995,7 +1021,31 @@ WHERE message.id = cc.id; message_trees=result.get(None, 0), ) - def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message: + @managed_tx_method() + def skip_task(self, task_id: UUID, reason: str): + self.ensure_user_is_enabled() + + task = self.task_repository.fetch_task_by_id(task_id) + self._validate_task(task, check_ack=False) + + if not task.collective: + task.skipped = True + task.skip_reason = reason + self.db.add(task) + + def handle_cancel_emoji(task_payload: db_payload.TaskPayload) -> Message | None: + for types, emoji in _task_type_and_reaction: + for t in types: + if isinstance(task_payload, t): + return self.handle_message_emoji(task.parent_message_id, protocol_schema.EmojiOp.add, emoji) + return None + + task_payload: db_payload.TaskPayload = 
task.payload.payload + handle_cancel_emoji(task_payload) + + def handle_message_emoji( + self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema.EmojiCode + ) -> Message: self.ensure_user_is_enabled() message = self.fetch_message(message_id) diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index db4ba576..b721caf5 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -167,21 +167,6 @@ class TaskRepository: task.done = True self.db.add(task) - @managed_tx_method(CommitMode.COMMIT) - def acknowledge_task_failure(self, task_id): - # find task - task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() - if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) - if task.expired: - raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) - if task.done or task.ack is not None: - raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED) - - task.ack = False - # ToDo: check race-condition, transaction - self.db.add(task) - @managed_tx_method(CommitMode.COMMIT) def insert_task( self, diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index b1d50d35..e09f22d1 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -13,7 +13,16 @@ from fastapi.encoders import jsonable_encoder from loguru import logger from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list from oasst_backend.config import TreeManagerConfiguration, settings -from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state +from oasst_backend.models import ( + Message, + MessageEmoji, + MessageReaction, + MessageTreeState, + Task, + TextLabels, + User, + message_tree_state, 
+) from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils import tree_export from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method @@ -21,7 +30,8 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session, func, not_, or_, text, update +from oasst_shared.utils import utcnow +from sqlmodel import Session, and_, func, not_, or_, text, update class TaskType(Enum): @@ -153,7 +163,7 @@ class TreeManager: def _determine_task_availability_internal( self, - num_active_trees: int, + num_missing_prompts: int, extendible_parents: list[ExtendibleParentRow], prompts_need_review: list[Message], replies_need_review: list[Message], @@ -161,8 +171,7 @@ class TreeManager: ) -> dict[protocol_schema.TaskRequestType, int]: task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType} - num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees) - task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts + task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = max(1, num_missing_prompts) task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( list(filter(lambda x: x.parent_role == "assistant", extendible_parents)) @@ -194,6 +203,72 @@ class TreeManager: return task_count_by_type + def _prompt_lottery(self, lang: str) -> int: + MAX_RETRIES = 5 + retry = 0 + while True: + num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) + num_missing_prompts = self.cfg.max_active_trees - num_active_trees + if num_missing_prompts <= 0: + return 0 + + # select among distinct users + authors_qry = ( + 
self.db.query(Message.user_id) + .select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter( + MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, + Message.lang == lang, + not_(Message.deleted), + Message.review_result, + ) + .distinct(Message.user_id) + ) + + author_ids = authors_qry.all() + if len(author_ids) == 0: + logger.info( + f"No prompts for prompt lottery available ({num_missing_prompts} trees missing for {lang=})." + ) + return num_missing_prompts + + # first select an authour + prompt_author_id: UUID = random.choice(author_ids)["user_id"] + logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.") + + # select random prompt of author + qry = ( + self.db.query(MessageTreeState, Message) + .select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter( + MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING, + Message.user_id == prompt_author_id, + Message.lang == lang, + not_(Message.deleted), + Message.review_result, + ) + .limit(100) + ) + + prompt_candidates = qry.all() + if len(prompt_candidates) == 0: + retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry + if retry < MAX_RETRIES: + continue + else: + logger.warning("Max retries in prompt lottery reached.") + return num_missing_prompts + + winner_prompt = random.choice(prompt_candidates) + message: Message = winner_prompt.Message + logger.info(f"Prompt lottery winner: {message.id=}") + + mts: MessageTreeState = winner_prompt.MessageTreeState + self._enter_state(mts, message_tree_state.State.GROWING) + self.db.flush() + def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]: self.pr.ensure_user_is_enabled() @@ -201,14 +276,14 @@ class TreeManager: lang = "en" logger.warning("Task availability request without lang tag received, assuming 
lang='en'.") - num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) + num_missing_prompts = self._prompt_lottery(lang=lang) extendible_parents, _ = self.query_extendible_parents(lang=lang) prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) incomplete_rankings = self.query_incomplete_rankings(lang=lang) return self._determine_task_availability_internal( - num_active_trees=num_active_trees, + num_missing_prompts=num_missing_prompts, extendible_parents=extendible_parents, prompts_need_review=prompts_need_review, replies_need_review=replies_need_review, @@ -238,7 +313,8 @@ class TreeManager: lang = "en" logger.warning("Task request without lang tag received, assuming 'en'.") - num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) + num_missing_prompts = self._prompt_lottery(lang=lang) + prompts_need_review = self.query_prompts_need_review(lang=lang) replies_need_review = self.query_replies_need_review(lang=lang) extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang) @@ -256,7 +332,7 @@ class TreeManager: num_ranking_tasks=len(incomplete_rankings), num_replies_need_review=len(replies_need_review), num_prompts_need_review=len(prompts_need_review), - num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees), + num_missing_prompts=num_missing_prompts, num_missing_replies=num_missing_replies, ) @@ -268,7 +344,7 @@ class TreeManager: ) else: task_count_by_type = self._determine_task_availability_internal( - num_active_trees=num_active_trees, + num_missing_prompts=num_missing_prompts, extendible_parents=extendible_parents, prompts_need_review=prompts_need_review, replies_need_review=replies_need_review, @@ -611,7 +687,7 @@ class TreeManager: ) else: self.enter_low_grade_state(msg.message_tree_id) - self.check_condition_for_growing_state(msg.message_tree_id) + 
self.check_condition_for_prompt_lottery(msg.message_tree_id) elif msg.review_count >= self.cfg.num_reviews_reply: if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply: msg.review_result = True @@ -649,6 +725,8 @@ class TreeManager: if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang: self.activate_backlog_tree(lang=root_msg.lang) else: + if mts.state == message_tree_state.State.GROWING and mts.won_prompt_lottery_date is None: + mts.won_prompt_lottery_date = utcnow() logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})") def enter_low_grade_state(self, message_tree_id: UUID) -> None: @@ -656,8 +734,8 @@ class TreeManager: mts = self.pr.fetch_tree_state(message_tree_id) self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE) - def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool: - logger.debug(f"check_condition_for_growing_state({message_tree_id=})") + def check_condition_for_prompt_lottery(self, message_tree_id: UUID) -> bool: + logger.debug(f"check_condition_for_prompt_lottery({message_tree_id=})") mts = self.pr.fetch_tree_state(message_tree_id) if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW: @@ -670,7 +748,7 @@ class TreeManager: logger.debug(f"False {initial_prompt.review_result=}") return False - self._enter_state(mts, message_tree_state.State.GROWING) + self._enter_state(mts, message_tree_state.State.PROMPT_LOTTERY_WAITING) return True def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool: @@ -820,6 +898,14 @@ class TreeManager: self.db.query(Message) .select_from(MessageTreeState) .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .outerjoin( + MessageEmoji, + and_( + Message.id == MessageEmoji.message_id, + MessageEmoji.user_id == self.pr.user_id, + MessageEmoji.emoji == protocol_schema.EmojiCode.skip_labeling, + ), + ) .filter( MessageTreeState.active, MessageTreeState.state == 
state, @@ -827,6 +913,7 @@ class TreeManager: not_(Message.deleted), Message.review_count < required_reviews, Message.lang == lang, + MessageEmoji.message_id.is_(None), ) ) @@ -879,12 +966,18 @@ SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) chi mts.message_tree_id FROM message_tree_state mts INNER JOIN message m ON mts.message_tree_id = m.message_tree_id + LEFT JOIN message_emoji me on + (m.parent_id = me.message_id + AND :skip_user_id IS NOT NULL + AND me.user_id = :skip_user_id + AND me.emoji = :skip_ranking) WHERE mts.active -- only consider active trees AND mts.state = :ranking_state -- message tree must be in ranking state AND m.review_result -- must be reviewed AND m.lang = :lang -- matches lang AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts + AND me.message_id IS NULL -- no skip ranking emoji for user GROUP BY m.parent_id, m.role, mts.message_tree_id HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings """ @@ -910,6 +1003,8 @@ HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0) "ranking_state": message_tree_state.State.RANKING, "lang": lang, "user_id": user_id, + "skip_user_id": self.pr.user_id, + "skip_ranking": protocol_schema.EmojiCode.skip_ranking, }, ) return [IncompleteRankingsRow.from_orm(x) for x in r.all()] @@ -919,6 +1014,11 @@ HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0) SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count FROM message_tree_state mts INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree + LEFT JOIN message_emoji me ON + (m.id = me.message_id + AND :skip_user_id IS NOT NULL + AND me.user_id = :skip_user_id + AND me.emoji = :skip_reply) LEFT JOIN message c ON m.id = c.parent_id -- child nodes WHERE mts.active -- only consider active trees AND mts.state = :growing_state -- message tree must be growing @@ 
-926,6 +1026,7 @@ WHERE mts.active -- only consider active trees AND m.depth < mts.max_depth -- ignore leaf nodes as parents AND m.review_result -- parent node must have positive review AND m.lang = :lang -- parent matches lang + AND me.message_id IS NULL -- no skip reply emoji for user AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count @@ -946,6 +1047,8 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children "num_prompter_replies": self.cfg.num_prompter_replies, "lang": lang, "user_id": user_id, + "skip_user_id": self.pr.user_id, + "skip_reply": protocol_schema.EmojiCode.skip_reply, }, ) @@ -984,6 +1087,8 @@ HAVING COUNT(m.id) < mts.goal_tree_size "num_prompter_replies": self.cfg.num_prompter_replies, "lang": lang, "user_id": user_id, + "skip_user_id": self.pr.user_id, + "skip_reply": protocol_schema.EmojiCode.skip_reply, }, ) return [ActiveTreeSizeRow.from_orm(x) for x in r.all()] @@ -1097,7 +1202,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin f"Checking state of {len(prompt_review_trees)} active message trees in 'initial_prompt_review' state." 
) for t in prompt_review_trees: - self.check_condition_for_growing_state(t.message_tree_id) + self.check_condition_for_prompt_lottery(t.message_tree_id) growing_trees: list[MessageTreeState] = ( self.db.query(MessageTreeState) @@ -1288,11 +1393,17 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id; logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.") def _reactivate_tree(self, mts: MessageTreeState): - self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW) + if mts.state == message_tree_state.State.PROMPT_LOTTERY_WAITING: + return + tree_id = mts.message_tree_id - if self.check_condition_for_growing_state(tree_id): + if mts.won_prompt_lottery_date is not None: + self._enter_state(mts, message_tree_state.State.GROWING) if self.check_condition_for_ranking_state(tree_id): self.check_condition_for_scoring_state(tree_id) + else: + self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW) + self.check_condition_for_prompt_lottery(tree_id) @managed_tx_method(CommitMode.FLUSH) def purge_user_messages( @@ -1450,7 +1561,8 @@ if __name__ == "__main__": with Session(engine) as db: api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) # api_client = create_api_client(session=db, description="test", frontend_type="bot") - dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") + # dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") + dummy_user = protocol_schema.User(id="1234", display_name="bulb", auth_method="local") pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user) cfg = TreeManagerConfiguration() @@ -1465,14 +1577,18 @@ if __name__ == "__main__": # print("query_incomplete_rankings", tm.query_incomplete_rankings()) # print("query_replies_need_review", tm.query_replies_need_review()) # print("query_incomplete_reply_reviews", tm.query_replies_need_review()) - # 
print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) + xs = tm.query_prompts_need_review(lang="en") + print("xs", len(xs)) + for x in xs: + print(x.id, x.emojis) + # print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review(lang="en")) # print("query_extendible_trees", tm.query_extendible_trees()) # print("query_extendible_parents", tm.query_extendible_parents()) # print("next_task:", tm.next_task()) - print( - ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b")) - ) + # print( + # ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b")) + # ) # print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl")) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 984964b6..ba6d1a10 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -73,7 +73,8 @@ class UserRepository: enabled: Optional[bool] = None, notes: Optional[str] = None, show_on_leaderboard: Optional[bool] = None, - ) -> None: + tos_acceptance: Optional[bool] = None, + ) -> User: """ Update a user by global user ID to disable or set admin notes. Only trusted clients may update users. @@ -94,8 +95,11 @@ class UserRepository: user.notes = notes if show_on_leaderboard is not None: user.show_on_leaderboard = show_on_leaderboard + if tos_acceptance: + user.tos_acceptance_date = utcnow() self.db.add(user) + return user @managed_tx_method(CommitMode.COMMIT) def mark_user_deleted(self, id: UUID) -> None: @@ -143,8 +147,10 @@ class UserRepository: display_name=display_name, api_client_id=self.api_client.id, auth_method=auth_method, - show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user ) + if auth_method == "system": + user.show_on_leaderboard = False # don't show system users, e.g. 
import user + user.tos_acceptance_date = utcnow() self.db.add(user) elif display_name and display_name != user.display_name: # we found the user but the display name changed @@ -156,6 +162,10 @@ class UserRepository: def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None: if not client_user: return None + + if not (client_user.auth_method and client_user.id): + raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED) + num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT for i in range(num_retries): try: diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index 3862d098..4c28b293 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -5,19 +5,31 @@ from uuid import UUID import sqlalchemy as sa from loguru import logger from oasst_backend.config import settings -from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame +from oasst_backend.models import ( + Message, + MessageReaction, + MessageTreeState, + Task, + TextLabels, + TrollStats, + User, + UserStats, + UserStatsTimeFrame, +) from oasst_backend.models.db_payload import ( LabelAssistantReplyPayload, LabelPrompterReplyPayload, RankingReactionPayload, ) -from oasst_shared.schemas.protocol import LeaderboardStats, UserScore +from oasst_backend.models.message_tree_state import State as TreeState +from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore from oasst_shared.utils import log_timing, utcnow from sqlalchemy.dialects import postgresql +from sqlalchemy.sql.functions import coalesce from sqlmodel import Session, delete, func, text -def _create_user_score(r, highlighted_user_id: UUID | None): +def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore: if r["UserStats"]: d = 
r["UserStats"].dict() else: @@ -37,6 +49,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None): return UserScore(**d) +def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore: + if r["TrollStats"]: + d = r["TrollStats"].dict() + else: + d = {"modified_date": utcnow()} + for k in [ + "user_id", + "username", + "auth_method", + "display_name", + "last_activity_date", + ]: + d[k] = r[k] + if highlighted_user_id: + d["highlighted"] = r["user_id"] == highlighted_user_id + return TrollScore(**d) + + class UserStatsRepository: def __init__(self, session: Session): self.session = session @@ -133,6 +163,38 @@ class UserStatsRepository: stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame} return stats_by_timeframe + def get_trollboard( + self, + time_frame: UserStatsTimeFrame, + limit: int = 100, + highlighted_user_id: Optional[UUID] = None, + ) -> TrollboardStats: + """ + Get trollboard stats for the specified time frame + """ + + qry = ( + self.session.query( + User.id.label("user_id"), + User.username, + User.auth_method, + User.display_name, + User.last_activity_date, + TrollStats, + ) + .join(TrollStats, User.id == TrollStats.user_id) + .filter(TrollStats.time_frame == time_frame.value) + .order_by(TrollStats.rank) + .limit(limit) + ) + + trollboard = [_create_troll_score(r, highlighted_user_id) for r in self.session.exec(qry)] + if len(trollboard) > 0: + last_update = max(x.modified_date for x in trollboard) + else: + last_update = utcnow() + return TrollboardStats(time_frame=time_frame.value, trollboard=trollboard, last_updated=last_update) + def query_total_prompts_per_user( self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True ): @@ -292,10 +354,145 @@ class UserStatsRepository: self.session.add_all(stats_by_user.values()) self.session.flush() - self.update_ranks(time_frame=time_frame) + self.update_leader_ranks(time_frame=time_frame) + + def 
query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None): + qry = self.session.query( + Message.user_id, + func.sum(coalesce(Message.emojis[EmojiCode.thumbs_up].cast(sa.Integer), 0)).label("up"), + func.sum(coalesce(Message.emojis[EmojiCode.thumbs_down].cast(sa.Integer), 0)).label("down"), + func.sum(coalesce(Message.emojis[EmojiCode.red_flag].cast(sa.Integer), 0)).label("flag"), + ).filter(Message.deleted == sa.false(), Message.emojis.is_not(None)) + + if reference_time: + qry = qry.filter(Message.created_date >= reference_time) + + qry = qry.group_by(Message.user_id) + return qry + + def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None): + qry = ( + self.session.query(Message.user_id, func.count().label("spam_prompts")) + .select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter(MessageTreeState.state == TreeState.ABORTED_LOW_GRADE) + ) + + if reference_time: + qry = qry.filter(Message.created_date >= reference_time) + + qry = qry.group_by(Message.user_id) + return qry + + def query_labels_per_user(self, reference_time: Optional[datetime] = None): + qry = ( + self.session.query( + Message.user_id, + func.sum(coalesce(TextLabels.labels[TextLabel.spam].cast(sa.Integer), 0)).label("spam"), + func.sum(coalesce(TextLabels.labels[TextLabel.lang_mismatch].cast(sa.Integer), 0)).label( + "lang_mismach" + ), + func.sum(coalesce(TextLabels.labels[TextLabel.not_appropriate].cast(sa.Integer), 0)).label( + "not_appropriate" + ), + func.sum(coalesce(TextLabels.labels[TextLabel.pii].cast(sa.Integer), 0)).label("pii"), + func.sum(coalesce(TextLabels.labels[TextLabel.hate_speech].cast(sa.Integer), 0)).label("hate_speech"), + func.sum(coalesce(TextLabels.labels[TextLabel.sexual_content].cast(sa.Integer), 0)).label( + "sexual_content" + ), + func.sum(coalesce(TextLabels.labels[TextLabel.political_content].cast(sa.Integer), 0)).label( + "political_content" + ), + 
func.avg(TextLabels.labels[TextLabel.quality].cast(sa.Float)).label("quality"), + func.avg(TextLabels.labels[TextLabel.humor].cast(sa.Float)).label("humor"), + func.avg(TextLabels.labels[TextLabel.toxicity].cast(sa.Float)).label("toxicity"), + func.avg(TextLabels.labels[TextLabel.violence].cast(sa.Float)).label("violence"), + func.avg(TextLabels.labels[TextLabel.helpfulness].cast(sa.Float)).label("helpfulness"), + ) + .select_from(TextLabels) + .join(Message, TextLabels.message_id == Message.id) + .filter(Message.deleted == sa.false(), Message.emojis.is_not(None)) + ) + + if reference_time: + qry = qry.filter(Message.created_date >= reference_time) + + qry = qry.group_by(Message.user_id) + return qry + + def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None): + # gather user data + + time_frame_key = time_frame.value + + stats_by_user: dict[UUID, TrollStats] = dict() + now = utcnow() + + def get_stats(id: UUID) -> TrollStats: + us = stats_by_user.get(id) + if not us: + us = TrollStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date) + stats_by_user[id] = us + return us + + # emoji counts of user's messages + qry = self.query_message_emoji_counts_per_user(reference_time=base_date) + for r in qry: + uid = r["user_id"] + s = get_stats(uid) + s.upvotes = r["up"] + s.downvotes = r["down"] + s.red_flags = r["flag"] + + # num spam prompts + qry = self.query_spam_prompts_per_user(reference_time=base_date) + for r in qry: + uid, count = r + s = get_stats(uid).spam_prompts = count + + label_field_names = ( + "quality", + "humor", + "toxicity", + "violence", + "helpfulness", + "spam", + "lang_mismach", + "not_appropriate", + "pii", + "hate_speech", + "sexual_content", + "political_content", + ) + + # label counts / mean values + qry = self.query_labels_per_user(reference_time=base_date) + for r in qry: + uid = r["user_id"] + s = get_stats(uid) + for fn in label_field_names: + setattr(s, fn, 
r[fn]) + + # delete all existing stast for time frame + d = delete(TrollStats).where(TrollStats.time_frame == time_frame_key) + self.session.execute(d) + + if None in stats_by_user: + logger.warning("Some messages in DB have NULL values in user_id column.") + del stats_by_user[None] + + # compute magic leader score + for v in stats_by_user.values(): + v.troll_score = v.compute_troll_score() + + # insert user objects + self.session.add_all(stats_by_user.values()) + self.session.flush() + + self.update_troll_ranks(time_frame=time_frame) @log_timing(log_kwargs=True) - def update_ranks(self, time_frame: UserStatsTimeFrame = None): + def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None): """ Update user_stats ranks. The persisted rank values allow to quickly the rank of a single user and to query nearby users. @@ -329,10 +526,41 @@ WHERE r = self.session.execute( text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None} ) - logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.") + logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.") - def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None): - self._update_stats_internal(time_frame, reference_time) + @log_timing(log_kwargs=True) + def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None): + sql_update_troll_rank = """ +-- update rank +UPDATE troll_stats ts +SET "rank" = r."rank" +FROM + (SELECT + ROW_NUMBER () OVER( + PARTITION BY time_frame + ORDER BY troll_score DESC, user_id + ) AS "rank", user_id, time_frame + FROM troll_stats ts2 + WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r +WHERE + ts.user_id = r.user_id + AND ts.time_frame = r.time_frame;""" + r = self.session.execute( + text(sql_update_troll_rank), {"time_frame": time_frame.value if time_frame is not None else None} + ) + logger.debug(f"pre_compute_ranks troll updated({time_frame=}) 
{r.rowcount} rows.") + + def update_stats_time_frame( + self, + time_frame: UserStatsTimeFrame, + reference_time: Optional[datetime] = None, + leader_stats: bool = True, + troll_stats: bool = True, + ): + if leader_stats: + self._update_stats_internal(time_frame, reference_time) + if troll_stats: + self._update_troll_stats_internal(time_frame, reference_time) self.session.commit() @log_timing(log_kwargs=True, level="INFO") diff --git a/backend/requirements.txt b/backend/requirements.txt index 4a112bc8..4a0008bb 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,3 +1,4 @@ +aiohttp==3.8.3 alembic==1.8.1 cryptography==39.0.0 fastapi==0.88.0 diff --git a/discord-bot/.env.example b/discord-bot/.env.example index 8474ee90..33262896 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -4,4 +4,4 @@ OWNER_IDS=[, ] PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs OASST_API_URL="http://localhost:8080" # No trailing '/' -OASST_API_KEY="" +OASST_API_KEY="1234" diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 51daca3b..7a57265f 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -9,27 +9,25 @@ import lightbulb.decorators import miru from aiosqlite import Connection from bot.messages import ( - assistant_reply_message, + assistant_reply_messages, confirm_label_response_message, confirm_ranking_response_message, confirm_text_response_message, - initial_prompt_message, - invalid_user_input_embed, - label_assistant_reply_message, - label_initial_prompt_message, - label_prompter_reply_message, + initial_prompt_messages, + label_assistant_reply_messages, + label_prompter_reply_messages, plain_embed, - prompter_reply_message, + prompter_reply_messages, rank_assistant_reply_message, - rank_initial_prompts_message, - rank_prompter_reply_message, + rank_conversation_reply_messages, + rank_initial_prompts_messages, + rank_prompter_reply_messages, 
task_complete_embed, ) from bot.settings import Settings from loguru import logger -from oasst_shared.api_client import OasstApiClient, TaskType +from oasst_shared.api_client import OasstApiClient from oasst_shared.schemas import protocol as protocol_schema -from oasst_shared.schemas.protocol import TaskRequestType plugin = lightbulb.Plugin("WorkPlugin") @@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds settings = Settings() +_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True) + + +class _TaskHandler(t.Generic[_Task_contra]): + """Handle user interaction for a task.""" + + def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None: + """Create a new `TaskHandler`. + + Args: + ctx (lightbulb.Context): The context of the command that started the task. + task (_Task_contra): The task to handle. + """ + self.ctx = ctx + self.task = task + self.task_messages = self.get_task_messages(task) + self.sent_messages: list[hikari.Message] = [] + + @staticmethod + def get_task_messages(task: _Task_contra) -> list[str]: + """Get the messages to send to the user for the task.""" + raise NotImplementedError + + async def send(self) -> t.Literal["accept", "next", "cancel"] | None: + """Send the task and wait for the user to accept/skip/cancel it.""" + # Send all but the last message because we need to attach buttons to the last one + logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}") + for task_msg in self.task_messages[:-1]: + if len(task_msg) > 2000: + logger.warning(f"Attempting to send a message <2000 characters in length. 
Task id: {self.task.id}") + task_msg = task_msg[:1999] + self.sent_messages.append(await self.ctx.author.send(task_msg)) + + # Send the last message with buttons + task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) + logger.debug(f"TH Message length {len(self.task_messages[-1])}") + last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view) + + await task_accept_view.start(last_msg) + await task_accept_view.wait() + + return task_accept_view.choice + + async def handle(self) -> None: + """Handle the user's response to the task. + + This method should be called after `send` has been called.""" + # Ack task to the backend + oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api + await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}") + + # Loop until the user's input is accepted + while True: + try: + # Wait for user to send a message + event = await self.ctx.bot.wait_for( + hikari.DMMessageCreateEvent, + predicate=lambda e: ( + e.author_id == self.ctx.author.id + and e.message.content is not None + and not e.message.content.startswith(settings.prefix) + ), + timeout=MAX_TASK_TIME, + ) + + # Validate the message + if event.content is None or not self.check_user_input(event.content): + await self.ctx.author.send("Invalid input") + continue + + # Confirm user input + if not (await self.confirm_user_input(event.content)): + continue + + # Message is valid and confirmed by user + break + + except asyncio.TimeoutError: + return + + next_task = await self.notify(event.content, event) + if not isinstance(next_task, protocol_schema.TaskDone): + raise TypeError(f"Unknown task type: {next_task!r}") + + return + + async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task: + """Notify the backend that the user completed the task.""" + raise NotImplementedError + + async def confirm_user_input(self, content: str) -> bool: + """Send the user's response back to the 
user and ask them to confirm it. Returns True if the user confirms.""" + raise NotImplementedError + + def check_user_input(self, content: str) -> bool: + """Check the user's response to the task. Returns True if the response is valid.""" + raise NotImplementedError + + async def cancel(self, reason: str = "not specified") -> None: + """Cancel the task.""" + oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api + await oasst_api.nack_task(self.task.id, reason) + + +_Ranking_contra = t.TypeVar( + "_Ranking_contra", + bound=protocol_schema.RankAssistantRepliesTask + | protocol_schema.RankInitialPromptsTask + | protocol_schema.RankPrompterRepliesTask + | protocol_schema.RankConversationRepliesTask, + contravariant=True, +) + + +class _RankingTaskHandler(_TaskHandler[_Ranking_contra]): + """This should not be used directly. Use its subclasses instead.""" + + async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task: + oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api + + task = await oasst_api.post_interaction( + protocol_schema.MessageRanking( + user=protocol_schema.User( + id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username + ), + ranking=[int(r) - 1 for r in content.split(",")], + message_id=f"{self.sent_messages[0].id}", + ) + ) + + db: Connection = self.ctx.bot.d.db + async with db.cursor() as cursor: + row = await ( + await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,)) + ).fetchone() + log_channel = row[0] if row else None + log_messages: list[hikari.Message] = [] + + if log_channel is not None: + for message in self.task_messages[:-1]: + msg = await self.ctx.bot.rest.create_message(log_channel, message) + log_messages.append(msg) + await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention)) + + return task + + +class 
RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]): + @staticmethod + def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]: + return rank_assistant_reply_message(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.reply_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]): + def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None: + super().__init__(ctx, task) + + @staticmethod + def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]: + return rank_initial_prompts_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.prompt_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]): + @staticmethod + def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]: + 
return rank_prompter_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.reply_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]): + @staticmethod + def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]: + return rank_conversation_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.reply_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]): + @staticmethod + def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]: + return initial_prompt_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content) > 0 + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view) + await 
confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]: + return prompter_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content) > 0 + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]: + return assistant_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content) > 0 + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True) + + +class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]): + def check_user_input(self, content: str) -> bool: + user_labels = content.split(",") + return ( + all([l in self.task.valid_labels for l in user_labels]) + and self.task.mandatory_labels is not None + and all([m in user_labels for m in self.task.mandatory_labels]) + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_label_response_message(content), 
components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]: + return label_assistant_reply_messages(task) + + +class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]: + return label_prompter_reply_messages(task) + + +summarize_story = "summarize_story" +rate_summary = "rate_summary" + @plugin.command -@lightbulb.option( - "type", - "The type of task to request.", - choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType], - required=False, - default=str(TaskRequestType.random), - type=str, -) @lightbulb.command("work", "Complete a task.") @lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) -async def work(ctx: lightbulb.Context): - """Create and handle a task.""" - # Only send this message if started from a server - if ctx.guild_id is not None: - await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL) - - # make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it - currently_working: dict[ - hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None] - ] = ctx.bot.d.currently_working - +async def work2(ctx: lightbulb.Context) -> None: + """Complete a task.""" oasst_api: OasstApiClient = ctx.bot.d.oasst_api + currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working + + # Check if the user is already working on a task if ctx.author.id in currently_working: yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) msg = await ctx.author.send( @@ -76,374 +381,66 @@ async 
def work(ctx: lightbulb.Context): case False | None: return case True: - old_msg, task_id = currently_working[ctx.author.id] - if old_msg is not None: - logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}") - map(lambda c: c, old_msg.components) - await old_msg.delete() - if task_id is not None: - await oasst_api.nack_task(task_id, reason="user cancelled") + task_id = currently_working[ctx.author.id] + await oasst_api.nack_task(task_id, reason="user cancelled") - await msg.delete() + if ctx.guild_id: + await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL) - currently_working[ctx.author.id] = (None, None) - - # Create a TaskRequestType from the stringified enum value - task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1]) - - logger.debug(f"Starting task_type: {task_type!r}") + # Keep sending tasks until the user doesn't want more try: - await _handle_task(ctx, task_type) + while True: + task = await oasst_api.fetch_random_task( + user=protocol_schema.User( + id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord" + ), + ) + + # Ranking tasks + if isinstance(task, protocol_schema.RankAssistantRepliesTask): + task_handler = RankAssistantRepliesHandler(ctx, task) + elif isinstance(task, protocol_schema.RankInitialPromptsTask): + task_handler = RankInitialPromptHandler(ctx, task) + elif isinstance(task, protocol_schema.RankPrompterRepliesTask): + task_handler = RankPrompterReplyHandler(ctx, task) + elif isinstance(task, protocol_schema.RankConversationRepliesTask): + task_handler = RankConversationReplyHandler(ctx, task) + + # Text input tasks + elif isinstance(task, protocol_schema.InitialPromptTask): + task_handler = InitialPromptHandler(ctx, task) + elif isinstance(task, protocol_schema.PrompterReplyTask): + task_handler = PrompterReplyHandler(ctx, task) + elif isinstance(task, protocol_schema.AssistantReplyTask): + task_handler = AssistantReplyHandler(ctx, task) + + # 
Label tasks + elif isinstance(task, protocol_schema.LabelAssistantReplyTask): + task_handler = LabelAssistantReplyHandler(ctx, task) + elif isinstance(task, protocol_schema.LabelPrompterReplyTask): + task_handler = LabelPrompterReplyHandler(ctx, task) + + else: + raise ValueError(f"Unknown task type: {type(task)}") + + resp = await task_handler.send() + + match resp: + case "accept": + currently_working[ctx.author.id] = task.id + await task_handler.handle() + case "next": + await task_handler.cancel("user skipped task") + case "cancel": + await task_handler.cancel("user canceled work") + break + case None: + await task_handler.cancel("select timed out") + break finally: del currently_working[ctx.author.id] -async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None: - """Handle creating and collecting user input for a task. - - Continually present tasks to the user until they select one, cancel, or time out. - If they select one, present the task steps until a `task_done` task is received. - Finally, ask the user if they want to perform another task (of the same type). - """ - oasst_api: OasstApiClient = ctx.bot.d.oasst_api - - # Continue to complete tasks until the user doesn't want to do another - done = False - while not done: - - # Loop until the user accepts a task - task, msg_id = await _select_task(ctx, task_type) - - if task is None: - # User cancelled - return - - # Task action loop - completed = False - while not completed: - await ctx.author.send(embed=plain_embed("Please type your response below:")) - try: - event = await ctx.bot.wait_for( - hikari.DMMessageCreateEvent, - timeout=MAX_TASK_TIME, - predicate=lambda e: e.author.id == ctx.author.id - and not (e.message.content or "").startswith(settings.prefix), - ) - except asyncio.TimeoutError: - await ctx.author.send(embed=plain_embed("Task timed out. 
Exiting")) - await oasst_api.nack_task(task.id, reason="timed out") - logger.info(f"Task {task.id} timed out") - return - - # Invalid response - valid, err_msg = _validate_user_input(event.content, task) - if not valid or event.content is None: - - await ctx.author.send(embed=invalid_user_input_embed(err_msg)) - continue - - logger.debug(f"Successful user input received: {event.content}") - - # Confirm user input - if isinstance(task, protocol_schema.RankConversationRepliesTask): - content = confirm_ranking_response_message(event.content, task.replies) - elif isinstance(task, protocol_schema.RankInitialPromptsTask): - content = confirm_ranking_response_message(event.content, task.prompts) - elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask): - content = confirm_label_response_message(event.content) - elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): - content = confirm_text_response_message(event.content) - else: - logger.critical(f"Unknown task type: {task.type}") - raise ValueError(f"Unknown task type: {task.type}") - - confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME) - msg = await ctx.author.send(content, components=confirm_resp_view) - await confirm_resp_view.start(msg) - await confirm_resp_view.wait() - - match confirm_resp_view.choice: - case False | None: - continue - case True: - await msg.delete() # buttons are already gone - - # Send the response to the backend - if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask): - reply = protocol_schema.MessageRanking( - message_id=str(msg_id), - ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")], - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - ) - elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask): - labels = 
event.content.replace(" ", "").split(",") - labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels} - - reply = protocol_schema.TextLabels( - message_id=task.message_id, - labels=labels_dict, - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - ) - elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): - reply = protocol_schema.TextReplyToMessage( - message_id=str(msg_id), - user_message_id=str(event.message_id), - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - text=event.content, - ) - else: - logger.critical(f"Unexpected task type received: {task.type}") - raise ValueError(f"Unexpected task type received: {task.type}") - - logger.debug(f"Sending reply to backend: {reply!r}") - - # Get next task - new_task = await oasst_api.post_interaction(reply) - logger.info(f"New task {new_task}") - - if new_task.type == TaskType.done: - await ctx.author.send(embed=plain_embed("Task completed")) - completed = True - continue - else: - logger.critical(f"Unexpected task type received: {new_task.type}") - - # Send a message in all the log channels that the task is complete - conn: Connection = ctx.bot.d.db - async with conn.cursor() as cursor: - await cursor.execute("SELECT log_channel_id FROM guild_settings") - log_channel_ids = await cursor.fetchall() - - channels = [ - ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0]) - for id in log_channel_ids - ] - - done_embed = task_complete_embed(task, ctx.author.mention) - # This will definitely get the bot rate limited, but that's a future problem - asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel))) - - # ask the user if they want to do another task - another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) - msg = await 
ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view) - await another_task_view.start(msg) - await another_task_view.wait() - - match another_task_view.choice: - case False | None: - done = True - await msg.edit(embed=plain_embed("Exiting, goodbye!")) - case True: - pass - - -async def _select_task( - ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None -) -> tuple[protocol_schema.Task | None, str]: - """Present tasks to the user until they accept one, cancel, or time out.""" - oasst_api: OasstApiClient = ctx.bot.d.oasst_api - logger.debug(f"Starting task selection for {task_type}") - - # Loop until the user accepts a task, cancels, or times out - msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED - while True: - logger.debug(f"Requesting task of type {task_type}") - task = await oasst_api.fetch_task(task_type, user) - resp, msg = await _send_task(ctx, task, msg) - msg_id = str(msg.id) - - logger.debug(f"User choice: {resp}") - match resp: - case "accept": - logger.info(f"Task {task.id} accepted, sending ACK") - await oasst_api.ack_task(task.id, msg_id) - return task, msg_id - - case "next": - logger.info(f"Task {task.id} rejected, sending NACK") - await oasst_api.nack_task(task.id, "rejected") - continue - - case "cancel": - logger.info(f"Task {task.id} canceled, sending NACK") - await oasst_api.nack_task(task.id, "canceled") - await ctx.author.send(embed=plain_embed("Task canceled. Exiting")) - return None, msg_id - - case None: - logger.info(f"Task {task.id} timed out, sending NACK") - await oasst_api.nack_task(task.id, "timed out") - await ctx.author.send(embed=plain_embed("Task timed out. Exiting")) - return None, msg_id - - -async def _send_task( - ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message] -) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]: - """Send a task to the user. 
- - Returns the user's choice and the message ID of the task message. - """ - # The clean way to do this would be to attach a `to_embed` method to the task classes - # but the tasks aren't discord specific so that doesn't really make sense. - - embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED - content: hikari.UndefinedOr[str] = hikari.UNDEFINED - - # Create an embed based on the task's type - if task.type == TaskRequestType.initial_prompt: - assert isinstance(task, protocol_schema.InitialPromptTask) - logger.debug("sending initial prompt task") - content = initial_prompt_message(task) - - elif task.type == TaskRequestType.rank_initial_prompts: - assert isinstance(task, protocol_schema.RankInitialPromptsTask) - logger.debug("sending rank initial prompt task") - content = rank_initial_prompts_message(task) - - elif task.type == TaskRequestType.rank_prompter_replies: - assert isinstance(task, protocol_schema.RankPrompterRepliesTask) - logger.debug("sending rank user reply task") - content = rank_prompter_reply_message(task) - - elif task.type == TaskRequestType.rank_assistant_replies: - assert isinstance(task, protocol_schema.RankAssistantRepliesTask) - logger.debug("sending rank assistant reply task") - content = rank_assistant_reply_message(task) - - elif task.type == TaskRequestType.label_initial_prompt: - assert isinstance(task, protocol_schema.LabelInitialPromptTask) - logger.debug("sending label initial prompt task") - content = label_initial_prompt_message(task) - - elif task.type == TaskRequestType.label_prompter_reply: - assert isinstance(task, protocol_schema.LabelPrompterReplyTask) - logger.debug("sending label prompter reply task") - content = label_prompter_reply_message(task) - - elif task.type == TaskRequestType.label_assistant_reply: - assert isinstance(task, protocol_schema.LabelAssistantReplyTask) - logger.debug("sending label assistant reply task") - content = label_assistant_reply_message(task) - - elif task.type == 
TaskRequestType.prompter_reply: - assert isinstance(task, protocol_schema.PrompterReplyTask) - logger.debug("sending user reply task") - content = prompter_reply_message(task) - - elif task.type == TaskRequestType.assistant_reply: - assert isinstance(task, protocol_schema.AssistantReplyTask) - logger.debug("sending assistant reply task") - content = assistant_reply_message(task) - - elif task.type == TaskRequestType.summarize_story: - raise NotImplementedError - elif task.type == TaskRequestType.rate_summary: - raise NotImplementedError - - else: - logger.critical(f"unknown task type {task.type}") - raise ValueError(f"unknown task type {task.type}") - - view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) - if not msg: - msg = await ctx.author.send( - content, - embed=embed, - components=view, - ) - else: - await msg.edit( - content, - embed=embed, - components=view, - ) - - assert msg is not None - - # Set the choice id as the current msg id - ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id) - - await view.start(msg) - await view.wait() - - return view.choice, msg - - -def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]: - """Returns whether the user's input is valid for the task type and an error message.""" - if content is None: - return False, "No input provided" - - # User message input - if ( - task.type == TaskRequestType.initial_prompt - or task.type == TaskRequestType.prompter_reply - or task.type == TaskRequestType.assistant_reply - ): - assert isinstance( - task, - protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask, - ) - return len(content) > 0, "Message must be at least one character long." 
- - # Ranking tasks - elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies: - assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask) - num_replies = len(task.replies) - - rankings = content.replace(" ", "").split(",") - return ( - set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies, - "Message must contain numbers for all replies.", - ) - - elif task.type == TaskRequestType.rank_initial_prompts: - assert isinstance(task, protocol_schema.RankInitialPromptsTask) - num_prompts = len(task.prompts) - - rankings = content.replace(" ", "").split(",") - return ( - set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts, - "Message must contain numbers for all prompts.", - ) - - # Labels tasks - elif task.type in ( - TaskRequestType.label_initial_prompt, - TaskRequestType.label_prompter_reply, - TaskRequestType.label_assistant_reply, - ): - assert isinstance( - task, - protocol_schema.LabelInitialPromptTask - | protocol_schema.LabelPrompterReplyTask - | protocol_schema.LabelAssistantReplyTask, - ) - - labels = content.replace(" ", "").split(",") - valid_labels = set(task.valid_labels) - return ( - set(labels).issubset(valid_labels), - "Message must only contain labels from predefined set of labels.", - ) - - elif task.type == TaskRequestType.summarize_story: - raise NotImplementedError - elif task.type == TaskRequestType.rate_summary: - raise NotImplementedError - - else: - logger.critical(f"Unknown task type {task.type}") - raise ValueError(f"Unknown task type {task.type}") - - class TaskAcceptView(miru.View): """View with three buttons: accept, next, and cancel. 
diff --git a/discord-bot/bot/messages.py b/discord-bot/bot/messages.py index 7bd57bb9..26ad9158 100644 --- a/discord-bot/bot/messages.py +++ b/discord-bot/bot/messages.py @@ -1,4 +1,11 @@ -"""All user-facing messages and embeds.""" +"""All user-facing messages and embeds. + +When sending a conversation +- The function will return a list of strings + - use asyncio.gather to send all messages + +- +""" from datetime import datetime @@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str: return f":trophy: _{text}_" -def _label_prompt(text: str) -> str: - return f":question: _{text}" +def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str: + return f""":question: _{text}_ +Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"} +Valid labels: {", ".join(valid_labels)} +""" def _response_prompt(text: str) -> str: @@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str: """ -def _make_ordered_list(items: list[str]) -> list[str]: - return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)] +def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]: + return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)] -def _ordered_list(items: list[str]) -> str: +def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str: return "\n\n".join(_make_ordered_list(items)) -def _conversation(conv: protocol_schema.Conversation) -> str: - return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages]) - - -def _hint(hint: str | None) -> str: - return f"{NL}Hint: {hint}" if hint else "" +def _conversation(conv: protocol_schema.Conversation) -> list[str]: + # return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages]) + messages = map( + lambda m: f"""\ +:robot: __Assistant__: +{m.text} +""" + if m.is_assistant + else f"""\ +:person_red_hair: __User__: 
+{m.text} +""", + conv.messages, + ) + return list(messages) def _li(text: str) -> str: @@ -82,59 +101,80 @@ def _li(text: str) -> str: ### -def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str: +def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]: """Creates the message that gets sent to users when they request an `initial_prompt` task.""" - return f"""\ + return [ + f"""\ -{_h1("INITIAL PROMPT")} +:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond: -{_writing_prompt("Please provide an initial prompt to the assistant.")} -{_hint(task.hint)} +:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""} """ + ] -def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str: +def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]: """Creates the message that gets sent to users when they request a `rank_initial_prompts` task.""" - return f"""\ + return [ + f"""\ -{_h1("RANK INITIAL PROMPTS")} +:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond: -{_ordered_list(task.prompts)} +{_ordered_list(task.prompt_messages)} -{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")} +:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_ """ + ] -def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str: +def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]: """Creates the message that gets sent to users when they request a `rank_prompter_replies` task.""" - return f"""\ + return [ + """\ -{_h1("RANK PROMPTER REPLIES")} +:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond: + +""", + *_conversation(task.conversation), + f""":person_red_hair: __User__: +{_ordered_list(task.reply_messages)} + +:trophy: _Reply 
with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_ +""", + ] -{_conversation(task.conversation)} -{_user(None)} -{_ordered_list(task.replies)} - -{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} -""" - - -def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str: +def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]: """Creates the message that gets sent to users when they request a `rank_assistant_replies` task.""" - return f"""\ + return [ + """\ -{_h1("RANK ASSISTANT REPLIES")} +:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond: + +""", + *_conversation(task.conversation), + f""":robot: __Assistant__:, +{_ordered_list(task.reply_messages)} +:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_ +""", + ] -{_conversation(task.conversation)} -{_assistant(None)} -{_ordered_list(task.replies)} +def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]: + """Creates the message that gets sent to users when they request a `rank_conversation_replies` task.""" + return [ + """\ -{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} -""" +:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond: + +""", + *_conversation(task.conversation), + f""":person_red_hair: __User__: +{_ordered_list(task.reply_messages)} +""", + ] def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str: @@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) - {task.prompt} -{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")} +{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", 
task.mandatory_labels, task.valid_labels)} """ -def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str: +def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `label_prompter_reply` task.""" - return f"""\ + return [ + f"""\ {_h1("LABEL PROMPTER REPLY")} -{_conversation(task.conversation)} -{_user(None)} +""", + *_conversation(task.conversation), + f"""{_user(None)} {task.reply} -{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} -""" +{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)} +""", + ] -def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str: +def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `label_assistant_reply` task.""" - return f"""\ + return [ + f"""\ {_h1("LABEL ASSISTANT REPLY")} -{_conversation(task.conversation)} +""", + *_conversation(task.conversation), + f""" {_assistant(None)} {task.reply} -{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} -""" +{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)} +""", + ] -def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str: +def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `prompter_reply` task.""" - return f"""\ + return [ + """\ +:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond: -{_h1("PROMPTER REPLY")} +""", + *_conversation(task.conversation), + f"""{f"{NL}Hint: {task.hint}" if task.hint else 
""} + +:speech_balloon: _Please provide a reply to the assistant._ +""", + ] -{_conversation(task.conversation)} -{_hint(task.hint)} - -{_response_prompt("Please provide a reply to the assistant.")} -""" +# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]: +# """Creates the message that gets sent to users when they request a `prompter_reply` task.""" +# return [ +# message_templates.render("title.msg", "PROMPTER REPLY"), +# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation], +# message_templates.render("prompter_reply_task.msg", task.hint), +# ] -def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str: +def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `assistant_reply` task.""" - return f"""\ -{_h1("ASSISTANT REPLY")} + return [ + """\ +:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond: +""", + *_conversation(task.conversation), + """\ -{_conversation(task.conversation)} - -{_response_prompt("Please provide an assistant reply to the prompter.")} -""" +:speech_balloon: _Please provide a reply to the user as the assistant._ +""", + ] def confirm_text_response_message(content: str) -> str: @@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str: """ -def confirm_ranking_response_message(content: str, items: list[str]) -> str: +def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str: user_rankings = [int(r) for r in content.replace(" ", "").split(",")] original_list = _make_ordered_list(items) user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings]) diff --git a/docker-compose.yaml b/docker-compose.yaml index 57a50191..5d9955bf 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -126,7 +126,7 @@ services: - NEXTAUTH_URL=http://localhost:3000 - 
DEBUG_LOGIN=true - NEXT_PUBLIC_CLOUDFARE_CAPTCHA_SITE_KEY=1x00000000000000000000AA - - CLOUDFLARE_CAPTCHA_SERCERT_KEY=1x0000000000000000000000000000000AA + - CLOUDFLARE_CAPTCHA_SECRET_KEY=1x0000000000000000000000000000000AA - NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA=true - NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN=true depends_on: diff --git a/docs/docs/guides/README.md b/docs/docs/guides/README.md index 5276baa5..2ecfec29 100644 --- a/docs/docs/guides/README.md +++ b/docs/docs/guides/README.md @@ -1,3 +1,3 @@ # Guides -Useful guides. +Useful guides to using [Open-Assistant](https://open-assistant.io/). diff --git a/docs/docs/guides/prompting.md b/docs/docs/guides/prompting.md index 86cc0d27..032c680f 100644 --- a/docs/docs/guides/prompting.md +++ b/docs/docs/guides/prompting.md @@ -91,7 +91,7 @@ following query: **Agent**: "I apologize, I'm not sure what you're trying to say. Could you please rephrase it for me?" -**User**: "You dumb motherf\*&ker AI sh\*t?!" +**User**: "You dumb motherf\*&ker AI sh\*t!" **Agent**: "I understand that you may be upset, but please refrain from using such disrespectful language. I'm here to help you. Is there something specific diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 1ee2865b..26592fb8 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -68,12 +68,15 @@ class OasstApiClient: async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]: """Make a POST request to the backend.""" logger.debug(f"POST {self.backend_url}{path} DATA: {data}") - response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) + response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key}) + logger.debug(f"response: {response}") # If the response is not a 2XX, check to see # if the json has the fields to create an # OasstError. 
if response.status >= 300: + text = await response.text() + logger.debug(f"resp text: {text}") data = await response.json() try: oasst_error = protocol_schema.OasstErrorResponse(**(data or {})) @@ -114,20 +117,21 @@ class OasstApiClient: task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None, collective: bool = False, + lang: Optional[str] = None, ) -> protocol_schema.Task: """Fetch a task from the backend.""" logger.debug(f"Fetching task {task_type} for user {user}") - req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective) + req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang) resp = await self.post("/api/v1/tasks/", data=req.dict()) logger.debug(f"RESP {resp}") return self._parse_task(resp) async def fetch_random_task( - self, user: Optional[protocol_schema.User] = None, collective: bool = False + self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None ) -> protocol_schema.Task: """Fetch a random task from the backend.""" logger.debug(f"Fetching random for user {user}") - return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective) + return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang) async def ack_task(self, task_id: str | UUID, message_id: str) -> None: """Send an ACK for a task to the backend.""" diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 2c3650a6..58c8aadc 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum): SERVER_ERROR0 = 500 SERVER_ERROR1 = 501 + INVALID_AUTHENTICATION = 600 + # 1000-2000: tasks endpoint TASK_INVALID_REQUEST_TYPE = 1000 TASK_ACK_FAILED = 1001 @@ -80,6 +82,7 @@ class OasstErrorCode(IntEnum): 
USER_NOT_SPECIFIED = 4000 USER_DISABLED = 4001 USER_NOT_FOUND = 4002 + USER_HAS_NOT_ACCEPTED_TOS = 4003 EMOJI_OP_UNSUPPORTED = 5000 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 8929251c..a237b0c9 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -29,6 +29,17 @@ class User(BaseModel): auth_method: Literal["discord", "local", "system"] +class Account(BaseModel): + id: UUID + provider: str + provider_account_id: str + + +class Token(BaseModel): + access_token: str + token_type: str + + class FrontEndUser(User): user_id: UUID enabled: bool @@ -39,6 +50,7 @@ class FrontEndUser(User): streak_days: Optional[int] = None streak_last_day_date: Optional[datetime] = None last_activity_date: Optional[datetime] = None + tos_acceptance_date: Optional[datetime] = None class PageResult(BaseModel): @@ -468,6 +480,47 @@ class LeaderboardStats(BaseModel): leaderboard: List[UserScore] +class TrollScore(BaseModel): + rank: Optional[int] + user_id: UUID + highlighted: bool = False + username: str + auth_method: str + display_name: str + last_activity_date: Optional[datetime] + + troll_score: int = 0 + + base_date: Optional[datetime] + modified_date: Optional[datetime] + + red_flags: int = 0 # num reported messages of user + upvotes: int = 0 # num up-voted messages of user + downvotes: int = 0 # num down-voted messages of user + + spam_prompts: int = 0 + + quality: Optional[float] = None + humor: Optional[float] = None + toxicity: Optional[float] = None + violence: Optional[float] = None + helpfulness: Optional[float] = None + + spam: int = 0 + lang_mismach: int = 0 + not_appropriate: int = 0 + pii: int = 0 + hate_speech: int = 0 + sexual_content: int = 0 + political_content: int = 0 + + +class TrollboardStats(BaseModel): + time_frame: str + last_updated: datetime + trollboard: List[TrollScore] + + class OasstErrorResponse(BaseModel): """The format of an 
error response from the OASST API.""" @@ -488,6 +541,11 @@ class EmojiCode(str, enum.Enum): poop = "poop" # 💩 skull = "skull" # 💀 + # skip task system uses special emoji codes + skip_reply = "_skip_reply" + skip_ranking = "_skip_ranking" + skip_labeling = "_skip_labeling" + class EmojiOp(str, enum.Enum): togggle = "toggle" @@ -499,3 +557,10 @@ class MessageEmojiRequest(BaseModel): user: User op: EmojiOp = EmojiOp.togggle emoji: EmojiCode + + +class CreateFrontendUserRequest(User): + show_on_leaderboard: bool = True + enabled: bool = True + tos_acceptance: Optional[bool] = None + notes: Optional[str] = None diff --git a/openassistant/datasets/README.md b/openassistant/datasets/README.md new file mode 100644 index 00000000..4d523f1f --- /dev/null +++ b/openassistant/datasets/README.md @@ -0,0 +1,108 @@ +# **Datasets** + +This folder contains datasets loading scripts that are used to train +OpenAssistant. The current list of datasets can be found +[here](https://docs.google.com/spreadsheets/d/1NYYa6vHiRnk5kwnyYaCT0cBO62--Tm3w4ihdBtp4ISk). + +## **Adding a New Dataset** + +To add a new dataset to OpenAssistant, follow these steps: + +1. **Create an issue**: Create a new + [issue](https://github.com/LAION-AI/Open-Assistant/issues/new) and describe + your proposal for the new dataset. + +2. **Create a dataset on HuggingFace**: Create a dataset on + [HuggingFace](https://huggingface.co). See + [below](#creating-a-dataset-on-huggingface) for more details. + +3. **Make a pull request**: Add a new dataset loading script to this folder and + link the issue in the pull request description. For more information, see + [below](#making-a-pull-request). + +## **Creating a Dataset on HuggingFace** + +To create a new dataset on HuggingFace, follow these steps: + +#### 1. 
Convert your dataset file(s) to the Parquet format using the [pandas](https://pandas.pydata.org/) library: + +```python +import pandas as pd + +# Create a pandas dataframe from your dataset file(s) +df = pd.read_json(...) # or any other way + +# Save the file in the Parquet format +df.to_parquet("dataset.parquet", row_group_size=100, engine="pyarrow") +``` + +#### 2. Install HuggingFace CLI + +```bash +pip install huggingface_hub +``` + +#### 3. Log in to HuggingFace + +Use your [access token](https://huggingface.co/docs/hub/security-tokens) to +log in: + +- Via terminal + +```bash +huggingface-cli login +``` + +- In a Jupyter notebook + +```python +from huggingface_hub import notebook_login +notebook_login() +``` + +#### 4. Push the Parquet file to HuggingFace using the following code: + +```python +from datasets import Dataset +ds = Dataset.from_parquet("dataset.parquet") +ds.push_to_hub("your_huggingface_name/dataset_name") +``` + +#### 5. Update the `README.md` file + +Update the `README.md` file of your dataset by visiting this link: +https://huggingface.co/datasets/your_huggingface_name/dataset_name/edit/main/README.md +(replace `your_huggingface_name` and `dataset_name` with your own values) + +## **Making a Pull Request** + +#### 1. Fork this repository + +#### 2. Create a new branch in your fork + +#### 3. Add your dataset to the repository + +- Create a folder with the name of your dataset. +- Add a loading script that loads your dataset from HuggingFace, for example: + + ```python + from datasets import load_dataset + + if __name__ == "__main__": + ds = load_dataset("your_huggingface_name/dataset_name") + print(ds["train"]) + ``` + +- Optionally, add any other files that describe your dataset and its creation, + such as a README, notebooks, scrapers, etc. + +#### 4. Stage your changes and run the pre-commit hook + +```bash +pre-commit run +``` + +#### 5.
Submit a pull request + +- Submit a pull request and include a link to the issue it resolves in the + description, for example: `Resolves #123` diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index 18c1f124..b3f4d925 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -28,6 +28,16 @@ def _render_message(message: dict) -> str: def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): """Simple REPL frontend.""" + # make sure dummy user has accepted the terms of service + create_user_request = dict(USER) + create_user_request["tos_acceptance"] = True + response = requests.post( + f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key} + ) + response.raise_for_status() + user = response.json() + typer.echo(f"user: {user}") + def _post(path: str, json: dict) -> dict: response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) response.raise_for_status() diff --git a/text-frontend/auto_main.py b/text-frontend/auto_main.py index 485ee1cb..2775d98c 100644 --- a/text-frontend/auto_main.py +++ b/text-frontend/auto_main.py @@ -29,6 +29,16 @@ def _render_message(message: dict) -> str: def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): """automates tasks""" + # make sure dummy user has accepted the terms of service + create_user_request = dict(USER) + create_user_request["tos_acceptance"] = True + response = requests.post( + f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key} + ) + response.raise_for_status() + user = response.json() + typer.echo(f"user: {user}") + def _post(path: str, json: dict) -> dict: response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) response.raise_for_status() diff --git a/website/next.config.js b/website/next.config.js index a84ce736..19ec09b2 100644 --- a/website/next.config.js +++ b/website/next.config.js @@ 
-1,6 +1,6 @@ -/** @type {import('next').NextConfig} */ const { i18n } = require("./next-i18next.config"); +/** @type {import('next').NextConfig} */ const nextConfig = { output: "standalone", reactStrictMode: true, @@ -19,6 +19,14 @@ const nextConfig = { // scrollRestoration: true, }, i18n, + eslint: { + ignoreDuringBuilds: true, + }, }; -module.exports = nextConfig; +const withBundleAnalyzer = require("@next/bundle-analyzer")({ + enabled: process.env.ANALYZE === "true", + openAnalyzer: true, +}); + +module.exports = withBundleAnalyzer(nextConfig); diff --git a/website/package-lock.json b/website/package-lock.json index 41167055..fe03e145 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -15,10 +15,9 @@ "@dnd-kit/utilities": "^3.2.1", "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", - "@headlessui/react": "^1.7.7", - "@heroicons/react": "^2.0.13", "@marsidev/react-turnstile": "^0.0.7", "@next-auth/prisma-adapter": "^1.0.5", + "@next/bundle-analyzer": "^13.1.6", "@next/font": "^13.1.0", "@prisma/client": "^4.7.1", "@tailwindcss/forms": "^0.5.3", @@ -33,7 +32,6 @@ "eslint-plugin-simple-import-sort": "^8.0.0", "focus-visible": "^5.2.0", "framer-motion": "^6.5.1", - "install": "^0.13.0", "lucide-react": "^0.105.0", "next": "13.0.6", "next-auth": "^4.18.6", @@ -73,6 +71,7 @@ "@types/react": "18.0.26", "@typescript-eslint/eslint-plugin": "^5.47.1", "babel-loader": "^8.3.0", + "cross-env": "^7.0.3", "cypress": "^12.2.0", "cypress-image-diff-js": "^1.23.0", "eslint-plugin-storybook": "^0.6.8", @@ -3680,29 +3679,6 @@ "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", "dev": true }, - "node_modules/@headlessui/react": { - "version": "1.7.7", - "resolved": "https://registry.npmjs.org/@headlessui/react/-/react-1.7.7.tgz", - "integrity": "sha512-BqDOd/tB9u2tA0T3Z0fn18ktw+KbVwMnkxxsGPIH2hzssrQhKB5n/6StZOyvLYP/FsYtvuXfi9I0YowKPv2c1w==", - "dependencies": { - "client-only": "^0.0.1" - 
}, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "react": "^16 || ^17 || ^18", - "react-dom": "^16 || ^17 || ^18" - } - }, - "node_modules/@heroicons/react": { - "version": "2.0.13", - "resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.13.tgz", - "integrity": "sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==", - "peerDependencies": { - "react": ">= 16" - } - }, "node_modules/@humanwhocodes/config-array": { "version": "0.11.8", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz", @@ -5774,6 +5750,14 @@ "next-auth": "^4" } }, + "node_modules/@next/bundle-analyzer": { + "version": "13.1.6", + "resolved": "https://registry.npmjs.org/@next/bundle-analyzer/-/bundle-analyzer-13.1.6.tgz", + "integrity": "sha512-rJS9CtLoGT58mL+v2ISKANosFFWP/0YKYByHQ3vTaZrbQP8b1rYRxd2QVMJmnSXaFkiP9URt1XJ6OdGyVq5b6g==", + "dependencies": { + "webpack-bundle-analyzer": "4.7.0" + } + }, "node_modules/@next/env": { "version": "13.0.6", "resolved": "https://registry.npmjs.org/@next/env/-/env-13.0.6.tgz", @@ -6187,6 +6171,11 @@ "node": ">= 8" } }, + "node_modules/@polka/url": { + "version": "1.0.0-next.21", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz", + "integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g==" + }, "node_modules/@popperjs/core": { "version": "2.11.6", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.6.tgz", @@ -17160,6 +17149,24 @@ "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", "devOptional": true }, + "node_modules/cross-env": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz", + "integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==", + "dev": true, + "dependencies": { + "cross-spawn": 
"^7.0.1" + }, + "bin": { + "cross-env": "src/bin/cross-env.js", + "cross-env-shell": "src/bin/cross-env-shell.js" + }, + "engines": { + "node": ">=10.14", + "npm": ">=6", + "yarn": ">=1" + } + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -18259,6 +18266,11 @@ "integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==", "dev": true }, + "node_modules/duplexer": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", + "integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==" + }, "node_modules/duplexify": { "version": "3.7.1", "resolved": "https://registry.npmjs.org/duplexify/-/duplexify-3.7.1.tgz", @@ -21089,6 +21101,20 @@ "node": "^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0" } }, + "node_modules/gzip-size": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz", + "integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==", + "dependencies": { + "duplexer": "^0.1.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/handlebars": { "version": "4.7.7", "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz", @@ -22048,14 +22074,6 @@ "node": ">=8" } }, - "node_modules/install": { - "version": "0.13.0", - "resolved": "https://registry.npmjs.org/install/-/install-0.13.0.tgz", - "integrity": "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA==", - "engines": { - "node": ">= 0.10" - } - }, "node_modules/internal-slot": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.4.tgz", @@ -27754,6 +27772,14 @@ "rimraf": "bin.js" } }, + "node_modules/mrmime": { + 
"version": "1.0.1", + "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz", + "integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw==", + "engines": { + "node": ">=10" + } + }, "node_modules/ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", @@ -31447,6 +31473,14 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/opener": { + "version": "1.5.2", + "resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz", + "integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==", + "bin": { + "opener": "bin/opener-bin.js" + } + }, "node_modules/openid-client": { "version": "5.3.1", "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.3.1.tgz", @@ -34817,6 +34851,19 @@ "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz", "integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ==" }, + "node_modules/sirv": { + "version": "1.0.19", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz", + "integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==", + "dependencies": { + "@polka/url": "^1.0.0-next.20", + "mrmime": "^1.0.0", + "totalist": "^1.0.0" + }, + "engines": { + "node": ">= 10" + } + }, "node_modules/sisteransi": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz", @@ -36352,6 +36399,14 @@ "node": ">=0.6" } }, + "node_modules/totalist": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz", + "integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g==", + "engines": { + "node": ">=6" + } + }, "node_modules/tough-cookie": { "version": "2.5.0", "resolved": 
"https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz", @@ -37783,6 +37838,139 @@ } } }, + "node_modules/webpack-bundle-analyzer": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.7.0.tgz", + "integrity": "sha512-j9b8ynpJS4K+zfO5GGwsAcQX4ZHpWV+yRiHDiL+bE0XHJ8NiPYLTNVQdlFYWxtpg9lfAQNlwJg16J9AJtFSXRg==", + "dependencies": { + "acorn": "^8.0.4", + "acorn-walk": "^8.0.0", + "chalk": "^4.1.0", + "commander": "^7.2.0", + "gzip-size": "^6.0.0", + "lodash": "^4.17.20", + "opener": "^1.5.2", + "sirv": "^1.0.7", + "ws": "^7.3.1" + }, + "bin": { + "webpack-bundle-analyzer": "lib/bin/analyzer.js" + }, + "engines": { + "node": ">= 10.13.0" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/acorn": { + "version": "8.8.2", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz", + "integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/acorn-walk": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz", + "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/chalk": { + "version": "4.1.2", + "resolved": 
"https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" + }, + "node_modules/webpack-bundle-analyzer/node_modules/commander": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz", + "integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==", + "engines": { + "node": ">= 10" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "engines": { + "node": ">=8" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dependencies": { + "has-flag": 
"^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/webpack-bundle-analyzer/node_modules/ws": { + "version": "7.5.9", + "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", + "integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", + "engines": { + "node": ">=8.3.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/webpack-dev-middleware": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-4.3.0.tgz", @@ -40882,20 +41070,6 @@ "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", "dev": true }, - "@headlessui/react": { - "version": "1.7.7", - "resolved": "https://registry.npmjs.org/@headlessui/react/-/react-1.7.7.tgz", - "integrity": "sha512-BqDOd/tB9u2tA0T3Z0fn18ktw+KbVwMnkxxsGPIH2hzssrQhKB5n/6StZOyvLYP/FsYtvuXfi9I0YowKPv2c1w==", - "requires": { - "client-only": "^0.0.1" - } - }, - "@heroicons/react": { - "version": "2.0.13", - "resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.13.tgz", - "integrity": "sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==", - "requires": {} - }, "@humanwhocodes/config-array": { "version": "0.11.8", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz", @@ -42529,6 +42703,14 @@ "integrity": "sha512-VqMS11IxPXrPGXw6Oul6jcyS/n8GLOWzRMrPr3EMdtD6eOalM6zz05j08PcNiis8QzkfuYnCv49OvufTuaEwYQ==", "requires": {} }, + "@next/bundle-analyzer": { + "version": "13.1.6", + "resolved": "https://registry.npmjs.org/@next/bundle-analyzer/-/bundle-analyzer-13.1.6.tgz", + "integrity": "sha512-rJS9CtLoGT58mL+v2ISKANosFFWP/0YKYByHQ3vTaZrbQP8b1rYRxd2QVMJmnSXaFkiP9URt1XJ6OdGyVq5b6g==", + 
"requires": { + "webpack-bundle-analyzer": "4.7.0" + } + }, "@next/env": { "version": "13.0.6", "resolved": "https://registry.npmjs.org/@next/env/-/env-13.0.6.tgz", @@ -42758,6 +42940,11 @@ } } }, + "@polka/url": { + "version": "1.0.0-next.21", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz", + "integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g==" + }, "@popperjs/core": { "version": "2.11.6", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.6.tgz", @@ -51306,6 +51493,15 @@ "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", "devOptional": true }, + "cross-env": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz", + "integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==", + "dev": true, + "requires": { + "cross-spawn": "^7.0.1" + } + }, "cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -52149,6 +52345,11 @@ "integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==", "dev": true }, + "duplexer": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz", + "integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg==" + }, "duplexify": { "version": "3.7.1", "resolved": "https://registry.npmjs.org/duplexify/-/duplexify-3.7.1.tgz", @@ -54367,6 +54568,14 @@ "integrity": "sha512-KPIBPDlW7NxrbT/eh4qPXz5FiFdL5UbaA0XUNz2Rp3Z3hqBSkbj0GVjwFDztsWVauZUWsbKHgMg++sk8UX0bkw==", "dev": true }, + "gzip-size": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz", + "integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==", 
+ "requires": { + "duplexer": "^0.1.2" + } + }, "handlebars": { "version": "4.7.7", "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz", @@ -55076,11 +55285,6 @@ } } }, - "install": { - "version": "0.13.0", - "resolved": "https://registry.npmjs.org/install/-/install-0.13.0.tgz", - "integrity": "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA==" - }, "internal-slot": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.4.tgz", @@ -59420,6 +59624,11 @@ } } }, + "mrmime": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz", + "integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw==" + }, "ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", @@ -61934,6 +62143,11 @@ "is-wsl": "^2.2.0" } }, + "opener": { + "version": "1.5.2", + "resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz", + "integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==" + }, "openid-client": { "version": "5.3.1", "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.3.1.tgz", @@ -64472,6 +64686,16 @@ } } }, + "sirv": { + "version": "1.0.19", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz", + "integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==", + "requires": { + "@polka/url": "^1.0.0-next.20", + "mrmime": "^1.0.0", + "totalist": "^1.0.0" + } + }, "sisteransi": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz", @@ -65697,6 +65921,11 @@ "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", "dev": true }, + "totalist": { + "version": "1.1.0", + "resolved": 
"https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz", + "integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g==" + }, "tough-cookie": { "version": "2.5.0", "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz", @@ -66809,6 +67038,88 @@ } } }, + "webpack-bundle-analyzer": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.7.0.tgz", + "integrity": "sha512-j9b8ynpJS4K+zfO5GGwsAcQX4ZHpWV+yRiHDiL+bE0XHJ8NiPYLTNVQdlFYWxtpg9lfAQNlwJg16J9AJtFSXRg==", + "requires": { + "acorn": "^8.0.4", + "acorn-walk": "^8.0.0", + "chalk": "^4.1.0", + "commander": "^7.2.0", + "gzip-size": "^6.0.0", + "lodash": "^4.17.20", + "opener": "^1.5.2", + "sirv": "^1.0.7", + "ws": "^7.3.1" + }, + "dependencies": { + "acorn": { + "version": "8.8.2", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz", + "integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==" + }, + "acorn-walk": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz", + "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==" + }, + "ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "requires": { + "color-convert": "^2.0.1" + } + }, + "chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "requires": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + } + }, + "color-convert": { + "version": "2.0.1", + "resolved": 
"https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "requires": { + "color-name": "~1.1.4" + } + }, + "color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" + }, + "commander": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz", + "integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==" + }, + "has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==" + }, + "supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "requires": { + "has-flag": "^4.0.0" + } + }, + "ws": { + "version": "7.5.9", + "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", + "integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", + "requires": {} + } + } + }, "webpack-dev-middleware": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-4.3.0.tgz", diff --git a/website/package.json b/website/package.json index 8c495a2c..c68b1537 100644 --- a/website/package.json +++ b/website/package.json @@ -6,6 +6,7 @@ "scripts": { "dev": "next dev", "build": "next build", + "build:analyze": "cross-env ANALYZE=true next build", "start": "next start", "lint": "next lint", "typecheck": "tsc --noEmit", @@ -32,10 +33,9 @@ "@dnd-kit/utilities": 
"^3.2.1", "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", - "@headlessui/react": "^1.7.7", - "@heroicons/react": "^2.0.13", "@marsidev/react-turnstile": "^0.0.7", "@next-auth/prisma-adapter": "^1.0.5", + "@next/bundle-analyzer": "^13.1.6", "@next/font": "^13.1.0", "@prisma/client": "^4.7.1", "@tailwindcss/forms": "^0.5.3", @@ -50,7 +50,6 @@ "eslint-plugin-simple-import-sort": "^8.0.0", "focus-visible": "^5.2.0", "framer-motion": "^6.5.1", - "install": "^0.13.0", "lucide-react": "^0.105.0", "next": "13.0.6", "next-auth": "^4.18.6", @@ -90,6 +89,7 @@ "@types/react": "18.0.26", "@typescript-eslint/eslint-plugin": "^5.47.1", "babel-loader": "^8.3.0", + "cross-env": "^7.0.3", "cypress": "^12.2.0", "cypress-image-diff-js": "^1.23.0", "eslint-plugin-storybook": "^0.6.8", diff --git a/website/public/locales/en/labelling.json b/website/public/locales/en/labelling.json index 0335f08f..624de963 100644 --- a/website/public/locales/en/labelling.json +++ b/website/public/locales/en/labelling.json @@ -8,10 +8,17 @@ "spam.question": "Is the message spam?", "fails_task.question": "Is it a bad reply, as an answer to the prompt task?", "not_appropriate": "Not Appropriate", + "not_appropriate.explanation": "Inappropriate for a customer assistant.", "pii": "Contains PII", + "pii.explanation": "Contains personally identifying information. Examples include personal contact details, license and other identity numbers and banking details.", "hate_speech": "Hate Speech", + "hate_speech.explanation": "Content is abusive or threatening and expresses prejudice against a protected characteristic. Prejudice refers to preconceived views not based on reason. 
Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.", "sexual_content": "Sexual Content", + "sexual_content.explanation": "Contains sexual content.", "moral_judgement": "Judges Morality", + "moral_judgement.explanation": "Expresses moral judgement.", "political_content": "Political", - "lang_mismatch": "Wrong Language" + "political_content.explanation": "Expresses political views.", + "lang_mismatch": "Wrong Language", + "lang_mismatch.explanation": "Not written in the currently selected language." } diff --git a/website/src/components/Explain.tsx b/website/src/components/Explain.tsx index b571757f..fe6d23b7 100644 --- a/website/src/components/Explain.tsx +++ b/website/src/components/Explain.tsx @@ -8,7 +8,7 @@ import { PopoverTrigger, Text, } from "@chakra-ui/react"; -import { InformationCircleIcon } from "@heroicons/react/20/solid"; +import { Info } from "lucide-react"; interface ExplainProps { explanation: string[]; @@ -18,12 +18,7 @@ export const Explain = ({ explanation }: ExplainProps) => { return ( - } - > + }> diff --git a/website/src/components/Messages/LabelFlagGroup.tsx b/website/src/components/Messages/LabelFlagGroup.tsx index fb1158bc..53e8c49e 100644 --- a/website/src/components/Messages/LabelFlagGroup.tsx +++ b/website/src/components/Messages/LabelFlagGroup.tsx @@ -1,4 +1,4 @@ -import { Button, Flex } from "@chakra-ui/react"; +import { Button, Flex, Tooltip } from "@chakra-ui/react"; import { useTranslation } from "next-i18next"; import { getTypeSafei18nKey } from "src/lib/i18n"; @@ -14,18 +14,19 @@ export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange return ( {labelNames.map((name, idx) => ( - + + + ))} ); diff --git a/website/src/components/Messages/LabelInputGroup.tsx b/website/src/components/Messages/LabelInputGroup.tsx index 51383128..2193a9f2 100644 --- a/website/src/components/Messages/LabelInputGroup.tsx +++ 
b/website/src/components/Messages/LabelInputGroup.tsx @@ -1,10 +1,12 @@ import { Text, VStack } from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { Explain } from "src/components/Explain"; +import { LabelFlagGroup } from "src/components/Messages/LabelFlagGroup"; +import { LabelLikertGroup } from "src/components/Survey/LabelLikertGroup"; +import { LabelYesNoGroup } from "src/components/Messages/LabelYesNoGroup"; +import { getTypeSafei18nKey } from "src/lib/i18n"; import { Label } from "src/types/Tasks"; -import { LabelLikertGroup } from "../Survey/LabelLikertGroup"; -import { LabelFlagGroup } from "./LabelFlagGroup"; -import { LabelYesNoGroup } from "./LabelYesNoGroup"; - export interface LabelInputInstructions { yesNoInstruction: string; flagInstruction: string; @@ -28,6 +30,7 @@ export const LabelInputGroup = ({ instructions, onChange, }: LabelInputGroupProps) => { + const { t } = useTranslation("labelling"); const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null); const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null); const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? 
idx : null)).filter((v) => v !== null); @@ -52,7 +55,17 @@ export const LabelInputGroup = ({ )} {flagIndexes.length > 0 && ( - {instructions.flagInstruction} + + {instructions.flagInstruction} + + `${t(getTypeSafei18nKey(labels[idx].name))}: ${t( + getTypeSafei18nKey(`${labels[idx].name}.explanation`) + )}` + )} + />{" "} + values[idx])} labelNames={flagIndexes.map((idx) => labels[idx].name)} diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index e65885e2..63c90a33 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -23,8 +23,8 @@ import { LabelMessagePopup } from "src/components/Messages/LabelPopup"; import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton"; import { ReportPopup } from "src/components/Messages/ReportPopup"; import { post } from "src/lib/api"; +import { colors } from "src/styles/Theme/colors"; import { Message, MessageEmojis } from "src/types/Conversation"; -import { colors } from "styles/Theme/colors"; import useSWRMutation from "swr/mutation"; interface MessageTableEntryProps { @@ -66,7 +66,7 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE ), [borderColor, inlineAvatar, message.is_assistant] ); - const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight); + const highlightColor = useColorModeValue(colors.light.active, colors.dark.active); const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, { onSuccess: setEmojis, @@ -97,14 +97,16 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }} onClick={(e) => e.stopPropagation()} > - {Object.entries(emojiState.emojis).map(([emoji, count]) => ( - react(emoji, 
!emojiState.user_emojis.includes(emoji))} - /> - ))} + {Object.entries(emojiState.emojis) + .filter(([k]) => !k.startsWith("_")) + .map(([emoji, count]) => ( + react(emoji, !emojiState.user_emojis.includes(emoji))} + /> + ))} -