mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into 911_control_email_signin
This commit is contained in:
@@ -33,14 +33,6 @@ jobs:
|
||||
WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }}
|
||||
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
|
||||
WEB_NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }}
|
||||
WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
|
||||
${{ secrets.DEV_WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }}
|
||||
WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY:
|
||||
${{ secrets.DEV_WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY }}
|
||||
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
|
||||
${{ secrets.DEV_WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }}
|
||||
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
|
||||
${{ secrets.DEV_WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }}
|
||||
S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
|
||||
S3_REGION: ${{ secrets.S3_REGION }}
|
||||
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
|
||||
@@ -49,11 +41,20 @@ jobs:
|
||||
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
|
||||
MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }}
|
||||
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
|
||||
MESSAGE_SIZE_LIMIT: ${{ vars.MESSAGE_SIZE_LIMIT }}
|
||||
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
|
||||
STATS_INTERVAL_DAY: ${{ vars.STATS_INTERVAL_DAY }}
|
||||
STATS_INTERVAL_WEEK: ${{ vars.STATS_INTERVAL_WEEK }}
|
||||
STATS_INTERVAL_MONTH: ${{ vars.STATS_INTERVAL_MONTH }}
|
||||
STATS_INTERVAL_TOTAL: ${{ vars.STATS_INTERVAL_TOTAL }}
|
||||
WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
|
||||
${{ secrets.WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }}
|
||||
WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY:
|
||||
${{ secrets.WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY }}
|
||||
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
|
||||
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }}
|
||||
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
|
||||
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
+1
-1
@@ -140,4 +140,4 @@ automatically deploy the built release to the dev machine.
|
||||
### Contribute a Dataset
|
||||
|
||||
See
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
|
||||
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/openassistant/datasets/README.md)
|
||||
|
||||
@@ -122,6 +122,9 @@
|
||||
TREE_MANAGER__MAX_CHILDREN_COUNT:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_CHILDREN_COUNT') |
|
||||
default('3', true) }}"
|
||||
MESSAGE_SIZE_LIMIT:
|
||||
"{{ lookup('ansible.builtin.env', 'MESSAGE_SIZE_LIMIT') |
|
||||
default('2000', true) }}"
|
||||
USER_STATS_INTERVAL_DAY:
|
||||
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_DAY') |
|
||||
default('5', true) }}"
|
||||
@@ -175,9 +178,9 @@
|
||||
NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY') }}"
|
||||
CLOUDFLARE_CAPTCHA_SERCERT_KEY:
|
||||
CLOUDFLARE_CAPTCHA_SECRET_KEY:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY') }}"
|
||||
'WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY') }}"
|
||||
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
|
||||
"{{ lookup('ansible.builtin.env',
|
||||
'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA') }}"
|
||||
|
||||
+3
-2
@@ -36,6 +36,7 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("user_stats", "streak_days")
|
||||
op.drop_column("user_stats", "streak_last_day_date")
|
||||
op.drop_column("user", "streak_days")
|
||||
op.drop_column("user", "streak_last_day_date")
|
||||
op.drop_column("user", "last_activity_date")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
"""add tos_acceptance_date to user
|
||||
|
||||
Revision ID: 55361f323d12
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 00:22:08.280251
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "55361f323d12"
|
||||
down_revision = "f60958968ff8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
|
||||
op.drop_column("user_stats", "streak_days")
|
||||
op.drop_column("user_stats", "streak_last_day_date")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
|
||||
)
|
||||
op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
|
||||
op.drop_column("user", "tos_acceptance_date")
|
||||
# ### end Alembic commands ###
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
"""add won_prompt_lottery_date to mts
|
||||
|
||||
Revision ID: f60958968ff8
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 10:10:38.301707
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f60958968ff8"
|
||||
down_revision = "7b8f0011e0b0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message_tree_state", sa.Column("won_prompt_lottery_date", sa.DateTime(timezone=True), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_tree_state", "won_prompt_lottery_date")
|
||||
# ### end Alembic commands ###
|
||||
+30
@@ -0,0 +1,30 @@
|
||||
"""add skip bool & skip_reason to task
|
||||
|
||||
Revision ID: 9e7ec4a9e3f2
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 21:46:49.971052
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9e7ec4a9e3f2"
|
||||
down_revision = "55361f323d12"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("task", sa.Column("skipped", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("task", sa.Column("skip_reason", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("task", "skip_reason")
|
||||
op.drop_column("task", "skipped")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,59 @@
|
||||
"""add troll_stats
|
||||
|
||||
Revision ID: 4d7e0b0ebe84
|
||||
Revises: 9e7ec4a9e3f2
|
||||
Create Date: 2023-02-02 15:44:12.647260
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4d7e0b0ebe84"
|
||||
down_revision = "9e7ec4a9e3f2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"troll_stats",
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("troll_score", sa.Integer(), nullable=False),
|
||||
sa.Column("rank", sa.Integer(), nullable=True),
|
||||
sa.Column("red_flags", sa.Integer(), nullable=False),
|
||||
sa.Column("upvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("downvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("spam_prompts", sa.Integer(), nullable=False),
|
||||
sa.Column("quality", sa.Float(), nullable=True),
|
||||
sa.Column("humor", sa.Float(), nullable=True),
|
||||
sa.Column("toxicity", sa.Float(), nullable=True),
|
||||
sa.Column("violence", sa.Float(), nullable=True),
|
||||
sa.Column("helpfulness", sa.Float(), nullable=True),
|
||||
sa.Column("spam", sa.Integer(), nullable=False),
|
||||
sa.Column("lang_mismach", sa.Integer(), nullable=False),
|
||||
sa.Column("not_appropriate", sa.Integer(), nullable=False),
|
||||
sa.Column("pii", sa.Integer(), nullable=False),
|
||||
sa.Column("hate_speech", sa.Integer(), nullable=False),
|
||||
sa.Column("sexual_content", sa.Integer(), nullable=False),
|
||||
sa.Column("political_content", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("user_id", "time_frame"),
|
||||
)
|
||||
op.create_index("ix_troll_stats__timeframe__user_id", "troll_stats", ["time_frame", "user_id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_troll_stats__timeframe__user_id", table_name="troll_stats")
|
||||
op.drop_table("troll_stats")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Add Account table
|
||||
|
||||
Revision ID: 8c8241d1f973
|
||||
Revises: 4d7e0b0ebe84
|
||||
Create Date: 2023-01-30 15:10:58.776315
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8c8241d1f973"
|
||||
down_revision = "4d7e0b0ebe84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"account",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("provider", "account", ["provider_account_id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("provider", table_name="account")
|
||||
op.drop_table("account")
|
||||
# ### end Alembic commands ###
|
||||
@@ -147,6 +147,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
ur = UserRepository(db=session, api_client=api_client)
|
||||
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
|
||||
pr = PromptRepository(
|
||||
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from oasst_backend.api.v1 import (
|
||||
stats,
|
||||
tasks,
|
||||
text_labels,
|
||||
trollboards,
|
||||
users,
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
|
||||
@@ -59,6 +59,37 @@ def query_frontend_user(
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol.FrontEndUser)
|
||||
def create_frontend_user(
|
||||
*,
|
||||
create_user: protocol.CreateFrontendUserRequest,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
user = ur.lookup_client_user(create_user, create_missing=True)
|
||||
|
||||
def changed(a, b) -> bool:
|
||||
return a is not None and a != b
|
||||
|
||||
# only call update_user if something changed
|
||||
if (
|
||||
changed(create_user.enabled, user.enabled)
|
||||
or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
|
||||
or changed(create_user.notes, user.notes)
|
||||
or (create_user.tos_acceptance and user.tos_acceptance_date is None)
|
||||
):
|
||||
user = ur.update_user(
|
||||
user.id,
|
||||
enabled=create_user.enabled,
|
||||
show_on_leaderboard=create_user.show_on_leaderboard,
|
||||
tos_acceptance=create_user.tos_acceptance,
|
||||
notes=create_user.notes,
|
||||
)
|
||||
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
auth_method: str,
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from oasst_backend import auth
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import Settings
|
||||
from oasst_backend.models import Account
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/discord")
|
||||
def login_discord(request: Request):
|
||||
redirect_uri = f"{get_callback_uri(request)}/discord"
|
||||
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={Settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
|
||||
raise HTTPException(status_code=302, headers={"location": auth_url})
|
||||
|
||||
|
||||
@router.get("/callback/discord", response_model=protocol_schema.Token)
|
||||
async def callback_discord(
|
||||
auth_code: str,
|
||||
request: Request,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
redirect_uri = f"{get_callback_uri(request)}/discord"
|
||||
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
# Exchange the auth code for a Discord access token
|
||||
async with session.post(
|
||||
"https://discord.com/api/oauth2/token",
|
||||
data={
|
||||
"client_id": Settings.AUTH_DISCORD_CLIENT_ID,
|
||||
"client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": "identify",
|
||||
},
|
||||
) as token_response:
|
||||
token_response_json = await token_response.json()
|
||||
access_token = token_response_json["access_token"]
|
||||
|
||||
# Retrieve user's Discord information using access token
|
||||
async with session.get(
|
||||
"https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
|
||||
) as user_response:
|
||||
user_response_json = await user_response.json()
|
||||
discord_id = user_response_json["id"]
|
||||
|
||||
account: Account = auth.get_account_from_discord_id(db, discord_id)
|
||||
|
||||
if not account:
|
||||
# Discord account is not linked to an OA account
|
||||
raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
|
||||
|
||||
# Discord account is valid and linked to an OA account -> create JWT
|
||||
access_token = auth.create_access_token(account)
|
||||
|
||||
return protocol_schema.Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
def get_callback_uri(request: Request):
|
||||
"""
|
||||
Gets the URI for the base callback endpoint with no provider name appended.
|
||||
"""
|
||||
# This seems ugly, not sure if there is a better way
|
||||
current_url = str(request.url)
|
||||
domain = current_url.split("/api/v1/")[0]
|
||||
redirect_uri = f"{domain}/api/v1/callback"
|
||||
return redirect_uri
|
||||
@@ -131,7 +131,7 @@ def tasks_acknowledge_failure(
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
pr.skip_task(task_id=task_id, reason=nack_request.reason)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import TrollboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{time_frame}", response_model=TrollboardStats)
|
||||
def get_trollboard(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> TrollboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_trollboard(time_frame, limit=max_count)
|
||||
@@ -191,6 +191,7 @@ def update_user(
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
tos_acceptance: Optional[bool] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
@@ -198,7 +199,7 @@ def update_user(
|
||||
Update a user by global user ID. Only trusted clients can update users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from jose import jwt
|
||||
from oasst_backend.config import Settings
|
||||
from oasst_backend.models import Account
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""
|
||||
Create an encoded JSON Web Token (JWT) using the given data.
|
||||
"""
|
||||
|
||||
expires_delta = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
|
||||
"""
|
||||
Get the Open-Assistant Account associated with the given Discord ID.
|
||||
"""
|
||||
|
||||
account: Account = (
|
||||
db.query(Account)
|
||||
.filter(
|
||||
Account.provider == "discord",
|
||||
Account.provider_account_id == discord_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return account
|
||||
@@ -13,7 +13,7 @@ class TreeManagerConfiguration(BaseModel):
|
||||
No new initial prompt tasks are handed out to users if this
|
||||
number is reached."""
|
||||
|
||||
max_tree_depth: int = 6
|
||||
max_tree_depth: int = 3
|
||||
"""Maximum depth of message tree."""
|
||||
|
||||
max_children_count: int = 3
|
||||
@@ -22,7 +22,7 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_prompter_replies: int = 1
|
||||
"""Number of prompter replies to collect per assistant reply."""
|
||||
|
||||
goal_tree_size: int = 15
|
||||
goal_tree_size: int = 12
|
||||
"""Total number of messages to gather per tree."""
|
||||
|
||||
num_reviews_initial_prompt: int = 3
|
||||
@@ -135,6 +135,11 @@ class Settings(BaseSettings):
|
||||
AUTH_LENGTH: int = 32
|
||||
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
|
||||
AUTH_COOKIE_NAME: str = "next-auth.session-token"
|
||||
AUTH_ALGORITHM: str = "HS256"
|
||||
AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
AUTH_DISCORD_CLIENT_ID: str = ""
|
||||
AUTH_DISCORD_CLIENT_SECRET: str = ""
|
||||
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
@@ -158,6 +163,9 @@ class Settings(BaseSettings):
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
|
||||
True # TODO: set False after ToS acceptance UI was added to web-frontend
|
||||
)
|
||||
|
||||
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from .message_toxicity import MessageToxicity
|
||||
from .message_tree_state import MessageTreeState
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .troll_stats import TrollStats
|
||||
from .user import User
|
||||
from .user_stats import UserStats, UserStatsTimeFrame
|
||||
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
"MessageEmoji",
|
||||
"TrollStats",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -46,6 +48,9 @@ class State(str, Enum):
|
||||
BACKLOG_RANKING = "backlog_ranking"
|
||||
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
|
||||
|
||||
PROMPT_LOTTERY_WAITING = "prompt_lottery_waiting"
|
||||
"""Initial prompt has passed spam check, waiting to be drawn to grow."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
State.INITIAL_PROMPT_REVIEW,
|
||||
@@ -63,6 +68,7 @@ TERMINAL_STATES = (
|
||||
State.SCORING_FAILED,
|
||||
State.HALTED_BY_MODERATOR,
|
||||
State.BACKLOG_RANKING,
|
||||
State.PROMPT_LOTTERY_WAITING,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,3 +84,4 @@ class MessageTreeState(SQLModel, table=True):
|
||||
state: str = Field(nullable=False, max_length=128, index=True)
|
||||
active: bool = Field(nullable=False, index=True)
|
||||
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
@@ -31,6 +31,8 @@ class Task(SQLModel, table=True):
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
ack: Optional[bool] = None
|
||||
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
skipped: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
skip_reason: str = Field(nullable=True, max_length=512)
|
||||
frontend_message_id: Optional[str] = None
|
||||
message_tree_id: Optional[UUID] = None
|
||||
parent_message_id: Optional[UUID] = None
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class TrollStats(SQLModel, table=True):
|
||||
__tablename__ = "troll_stats"
|
||||
__table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)
|
||||
|
||||
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
|
||||
)
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
troll_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
rank: int = Field(nullable=True)
|
||||
|
||||
red_flags: int = 0 # num reported messages of user
|
||||
upvotes: int = 0 # num up-voted messages of user
|
||||
downvotes: int = 0 # num down-voted messages of user
|
||||
|
||||
spam_prompts: int = 0
|
||||
|
||||
quality: float = Field(nullable=True)
|
||||
humor: float = Field(nullable=True)
|
||||
toxicity: float = Field(nullable=True)
|
||||
violence: float = Field(nullable=True)
|
||||
helpfulness: float = Field(nullable=True)
|
||||
|
||||
spam: int = 0
|
||||
lang_mismach: int = 0
|
||||
not_appropriate: int = 0
|
||||
pii: int = 0
|
||||
hate_speech: int = 0
|
||||
sexual_content: int = 0
|
||||
political_content: int = 0
|
||||
|
||||
def compute_troll_score(self) -> int:
|
||||
return (
|
||||
self.red_flags * 3
|
||||
- self.upvotes
|
||||
+ self.downvotes
|
||||
+ self.spam_prompts
|
||||
+ self.lang_mismach
|
||||
+ self.not_appropriate
|
||||
+ self.pii
|
||||
+ self.hate_speech
|
||||
+ self.sexual_content
|
||||
+ self.political_content
|
||||
)
|
||||
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
# terms of service acceptance date
|
||||
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
def to_protocol_frontend_user(self):
|
||||
return protocol.FrontEndUser(
|
||||
user_id=self.id,
|
||||
@@ -55,4 +58,19 @@ class User(SQLModel, table=True):
|
||||
streak_days=self.streak_days,
|
||||
streak_last_day_date=self.streak_last_day_date,
|
||||
last_activity_date=self.last_activity_date,
|
||||
tos_acceptance_date=self.tos_acceptance_date,
|
||||
)
|
||||
|
||||
|
||||
class Account(SQLModel, table=True):
|
||||
__tablename__ = "account"
|
||||
__table_args__ = (Index("provider", "provider_account_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
user_id: UUID = Field(foreign_key="user.id")
|
||||
provider: str = Field(nullable=False, max_length=128, default="email") # discord or email
|
||||
provider_account_id: str = Field(nullable=False, max_length=128)
|
||||
|
||||
@@ -35,7 +35,21 @@ from oasst_shared.utils import unaware_to_utc, utcnow
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
_task_type_and_reaction = (
|
||||
(
|
||||
(db_payload.PrompterReplyPayload, db_payload.AssistantReplyPayload),
|
||||
protocol_schema.EmojiCode.skip_reply,
|
||||
),
|
||||
(
|
||||
(db_payload.LabelInitialPromptPayload, db_payload.LabelConversationReplyPayload),
|
||||
protocol_schema.EmojiCode.skip_labeling,
|
||||
),
|
||||
(
|
||||
(db_payload.RankInitialPromptsPayload, db_payload.RankConversationRepliesPayload),
|
||||
protocol_schema.EmojiCode.skip_ranking,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
@@ -77,7 +91,14 @@ class PromptRepository:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
if self.user.deleted or not self.user.enabled:
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
|
||||
|
||||
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
|
||||
raise OasstError(
|
||||
"User has not accepted terms of service.",
|
||||
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
|
||||
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
|
||||
)
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
@@ -90,7 +111,7 @@ class PromptRepository:
|
||||
raise OasstError(
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.",
|
||||
OasstErrorCode.MESSAGE_NOT_FOUND,
|
||||
HTTP_404_NOT_FOUND,
|
||||
HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
return message
|
||||
|
||||
@@ -139,7 +160,12 @@ class PromptRepository:
|
||||
return message
|
||||
|
||||
def _validate_task(
|
||||
self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None
|
||||
self,
|
||||
task: Task,
|
||||
*,
|
||||
task_id: Optional[UUID] = None,
|
||||
frontend_message_id: Optional[str] = None,
|
||||
check_ack: bool = True,
|
||||
) -> Task:
|
||||
if task is None:
|
||||
if task_id:
|
||||
@@ -150,7 +176,7 @@ class PromptRepository:
|
||||
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if not task.ack:
|
||||
if check_ack and not task.ack:
|
||||
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
|
||||
if task.done:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
@@ -675,7 +701,7 @@ class PromptRepository:
|
||||
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
|
||||
return message
|
||||
|
||||
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
|
||||
@@ -874,7 +900,7 @@ class PromptRepository:
|
||||
|
||||
if api_client_id != self.api_client.id:
|
||||
# Unprivileged api client asks for foreign messages
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
|
||||
|
||||
qry = self.db.query(Message)
|
||||
if user_id:
|
||||
@@ -995,7 +1021,31 @@ WHERE message.id = cc.id;
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
|
||||
@managed_tx_method()
|
||||
def skip_task(self, task_id: UUID, reason: str):
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
task = self.task_repository.fetch_task_by_id(task_id)
|
||||
self._validate_task(task, check_ack=False)
|
||||
|
||||
if not task.collective:
|
||||
task.skipped = True
|
||||
task.skip_reason = reason
|
||||
self.db.add(task)
|
||||
|
||||
def handle_cancel_emoji(task_payload: db_payload.TaskPayload) -> Message | None:
|
||||
for types, emoji in _task_type_and_reaction:
|
||||
for t in types:
|
||||
if isinstance(task_payload, t):
|
||||
return self.handle_message_emoji(task.parent_message_id, protocol_schema.EmojiOp.add, emoji)
|
||||
return None
|
||||
|
||||
task_payload: db_payload.TaskPayload = task.payload.payload
|
||||
handle_cancel_emoji(task_payload)
|
||||
|
||||
def handle_message_emoji(
|
||||
self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema.EmojiCode
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
message = self.fetch_message(message_id)
|
||||
|
||||
@@ -167,21 +167,6 @@ class TaskRepository:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def insert_task(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,16 @@ from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
|
||||
from oasst_backend.models import (
|
||||
Message,
|
||||
MessageEmoji,
|
||||
MessageReaction,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
User,
|
||||
message_tree_state,
|
||||
)
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils import tree_export
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
@@ -21,7 +30,8 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, not_, or_, text, update
|
||||
from oasst_shared.utils import utcnow
|
||||
from sqlmodel import Session, and_, func, not_, or_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -153,7 +163,7 @@ class TreeManager:
|
||||
|
||||
def _determine_task_availability_internal(
|
||||
self,
|
||||
num_active_trees: int,
|
||||
num_missing_prompts: int,
|
||||
extendible_parents: list[ExtendibleParentRow],
|
||||
prompts_need_review: list[Message],
|
||||
replies_need_review: list[Message],
|
||||
@@ -161,8 +171,7 @@ class TreeManager:
|
||||
) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}
|
||||
|
||||
num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
|
||||
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = max(1, num_missing_prompts)
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
|
||||
list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
||||
@@ -194,6 +203,72 @@ class TreeManager:
|
||||
|
||||
return task_count_by_type
|
||||
|
||||
def _prompt_lottery(self, lang: str) -> int:
|
||||
MAX_RETRIES = 5
|
||||
retry = 0
|
||||
while True:
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
num_missing_prompts = self.cfg.max_active_trees - num_active_trees
|
||||
if num_missing_prompts <= 0:
|
||||
return 0
|
||||
|
||||
# select among distinct users
|
||||
authors_qry = (
|
||||
self.db.query(Message.user_id)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.distinct(Message.user_id)
|
||||
)
|
||||
|
||||
author_ids = authors_qry.all()
|
||||
if len(author_ids) == 0:
|
||||
logger.info(
|
||||
f"No prompts for prompt lottery available ({num_missing_prompts} trees missing for {lang=})."
|
||||
)
|
||||
return num_missing_prompts
|
||||
|
||||
# first select an authour
|
||||
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
|
||||
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
|
||||
|
||||
# select random prompt of author
|
||||
qry = (
|
||||
self.db.query(MessageTreeState, Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(
|
||||
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
|
||||
Message.user_id == prompt_author_id,
|
||||
Message.lang == lang,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
)
|
||||
.limit(100)
|
||||
)
|
||||
|
||||
prompt_candidates = qry.all()
|
||||
if len(prompt_candidates) == 0:
|
||||
retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry
|
||||
if retry < MAX_RETRIES:
|
||||
continue
|
||||
else:
|
||||
logger.warning("Max retries in prompt lottery reached.")
|
||||
return num_missing_prompts
|
||||
|
||||
winner_prompt = random.choice(prompt_candidates)
|
||||
message: Message = winner_prompt.Message
|
||||
logger.info(f"Prompt lottery winner: {message.id=}")
|
||||
|
||||
mts: MessageTreeState = winner_prompt.MessageTreeState
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
self.db.flush()
|
||||
|
||||
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
@@ -201,14 +276,14 @@ class TreeManager:
|
||||
lang = "en"
|
||||
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang)
|
||||
extendible_parents, _ = self.query_extendible_parents(lang=lang)
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
|
||||
|
||||
return self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
num_missing_prompts=num_missing_prompts,
|
||||
extendible_parents=extendible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
@@ -238,7 +313,8 @@ class TreeManager:
|
||||
lang = "en"
|
||||
logger.warning("Task request without lang tag received, assuming 'en'.")
|
||||
|
||||
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
|
||||
num_missing_prompts = self._prompt_lottery(lang=lang)
|
||||
|
||||
prompts_need_review = self.query_prompts_need_review(lang=lang)
|
||||
replies_need_review = self.query_replies_need_review(lang=lang)
|
||||
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
|
||||
@@ -256,7 +332,7 @@ class TreeManager:
|
||||
num_ranking_tasks=len(incomplete_rankings),
|
||||
num_replies_need_review=len(replies_need_review),
|
||||
num_prompts_need_review=len(prompts_need_review),
|
||||
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
||||
num_missing_prompts=num_missing_prompts,
|
||||
num_missing_replies=num_missing_replies,
|
||||
)
|
||||
|
||||
@@ -268,7 +344,7 @@ class TreeManager:
|
||||
)
|
||||
else:
|
||||
task_count_by_type = self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
num_missing_prompts=num_missing_prompts,
|
||||
extendible_parents=extendible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
@@ -611,7 +687,7 @@ class TreeManager:
|
||||
)
|
||||
else:
|
||||
self.enter_low_grade_state(msg.message_tree_id)
|
||||
self.check_condition_for_growing_state(msg.message_tree_id)
|
||||
self.check_condition_for_prompt_lottery(msg.message_tree_id)
|
||||
elif msg.review_count >= self.cfg.num_reviews_reply:
|
||||
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
|
||||
msg.review_result = True
|
||||
@@ -649,6 +725,8 @@ class TreeManager:
|
||||
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
else:
|
||||
if mts.state == message_tree_state.State.GROWING and mts.won_prompt_lottery_date is None:
|
||||
mts.won_prompt_lottery_date = utcnow()
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
|
||||
@@ -656,8 +734,8 @@ class TreeManager:
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
|
||||
def check_condition_for_prompt_lottery(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_prompt_lottery({message_tree_id=})")
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW:
|
||||
@@ -670,7 +748,7 @@ class TreeManager:
|
||||
logger.debug(f"False {initial_prompt.review_result=}")
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
self._enter_state(mts, message_tree_state.State.PROMPT_LOTTERY_WAITING)
|
||||
return True
|
||||
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
|
||||
@@ -820,6 +898,14 @@ class TreeManager:
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.outerjoin(
|
||||
MessageEmoji,
|
||||
and_(
|
||||
Message.id == MessageEmoji.message_id,
|
||||
MessageEmoji.user_id == self.pr.user_id,
|
||||
MessageEmoji.emoji == protocol_schema.EmojiCode.skip_labeling,
|
||||
),
|
||||
)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == state,
|
||||
@@ -827,6 +913,7 @@ class TreeManager:
|
||||
not_(Message.deleted),
|
||||
Message.review_count < required_reviews,
|
||||
Message.lang == lang,
|
||||
MessageEmoji.message_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -879,12 +966,18 @@ SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) chi
|
||||
mts.message_tree_id
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
LEFT JOIN message_emoji me on
|
||||
(m.parent_id = me.message_id
|
||||
AND :skip_user_id IS NOT NULL
|
||||
AND me.user_id = :skip_user_id
|
||||
AND me.emoji = :skip_ranking)
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :ranking_state -- message tree must be in ranking state
|
||||
AND m.review_result -- must be reviewed
|
||||
AND m.lang = :lang -- matches lang
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
AND me.message_id IS NULL -- no skip ranking emoji for user
|
||||
GROUP BY m.parent_id, m.role, mts.message_tree_id
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
@@ -910,6 +1003,8 @@ HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
"ranking_state": message_tree_state.State.RANKING,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
"skip_user_id": self.pr.user_id,
|
||||
"skip_ranking": protocol_schema.EmojiCode.skip_ranking,
|
||||
},
|
||||
)
|
||||
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
|
||||
@@ -919,6 +1014,11 @@ HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
LEFT JOIN message_emoji me ON
|
||||
(m.id = me.message_id
|
||||
AND :skip_user_id IS NOT NULL
|
||||
AND me.user_id = :skip_user_id
|
||||
AND me.emoji = :skip_reply)
|
||||
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :growing_state -- message tree must be growing
|
||||
@@ -926,6 +1026,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND m.lang = :lang -- parent matches lang
|
||||
AND me.message_id IS NULL -- no skip reply emoji for user
|
||||
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
|
||||
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
@@ -946,6 +1047,8 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"num_prompter_replies": self.cfg.num_prompter_replies,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
"skip_user_id": self.pr.user_id,
|
||||
"skip_reply": protocol_schema.EmojiCode.skip_reply,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -984,6 +1087,8 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
"num_prompter_replies": self.cfg.num_prompter_replies,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
"skip_user_id": self.pr.user_id,
|
||||
"skip_reply": protocol_schema.EmojiCode.skip_reply,
|
||||
},
|
||||
)
|
||||
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
|
||||
@@ -1097,7 +1202,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
f"Checking state of {len(prompt_review_trees)} active message trees in 'initial_prompt_review' state."
|
||||
)
|
||||
for t in prompt_review_trees:
|
||||
self.check_condition_for_growing_state(t.message_tree_id)
|
||||
self.check_condition_for_prompt_lottery(t.message_tree_id)
|
||||
|
||||
growing_trees: list[MessageTreeState] = (
|
||||
self.db.query(MessageTreeState)
|
||||
@@ -1288,11 +1393,17 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.")
|
||||
|
||||
def _reactivate_tree(self, mts: MessageTreeState):
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
if mts.state == message_tree_state.State.PROMPT_LOTTERY_WAITING:
|
||||
return
|
||||
|
||||
tree_id = mts.message_tree_id
|
||||
if self.check_condition_for_growing_state(tree_id):
|
||||
if mts.won_prompt_lottery_date is not None:
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
if self.check_condition_for_ranking_state(tree_id):
|
||||
self.check_condition_for_scoring_state(tree_id)
|
||||
else:
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
self.check_condition_for_prompt_lottery(tree_id)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user_messages(
|
||||
@@ -1450,7 +1561,8 @@ if __name__ == "__main__":
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
# api_client = create_api_client(session=db, description="test", frontend_type="bot")
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
# dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
dummy_user = protocol_schema.User(id="1234", display_name="bulb", auth_method="local")
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
|
||||
cfg = TreeManagerConfiguration()
|
||||
@@ -1465,14 +1577,18 @@ if __name__ == "__main__":
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
# print("query_incomplete_reply_reviews", tm.query_replies_need_review())
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
xs = tm.query_prompts_need_review(lang="en")
|
||||
print("xs", len(xs))
|
||||
for x in xs:
|
||||
print(x.id, x.emojis)
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review(lang="en"))
|
||||
# print("query_extendible_trees", tm.query_extendible_trees())
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
)
|
||||
# print(
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
# )
|
||||
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
|
||||
@@ -73,7 +73,8 @@ class UserRepository:
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
) -> None:
|
||||
tos_acceptance: Optional[bool] = None,
|
||||
) -> User:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
|
||||
@@ -94,8 +95,11 @@ class UserRepository:
|
||||
user.notes = notes
|
||||
if show_on_leaderboard is not None:
|
||||
user.show_on_leaderboard = show_on_leaderboard
|
||||
if tos_acceptance:
|
||||
user.tos_acceptance_date = utcnow()
|
||||
|
||||
self.db.add(user)
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_user_deleted(self, id: UUID) -> None:
|
||||
@@ -143,8 +147,10 @@ class UserRepository:
|
||||
display_name=display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=auth_method,
|
||||
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
|
||||
)
|
||||
if auth_method == "system":
|
||||
user.show_on_leaderboard = False # don't show system users, e.g. import user
|
||||
user.tos_acceptance_date = utcnow()
|
||||
self.db.add(user)
|
||||
elif display_name and display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
@@ -156,6 +162,10 @@ class UserRepository:
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
|
||||
if not client_user:
|
||||
return None
|
||||
|
||||
if not (client_user.auth_method and client_user.id):
|
||||
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
|
||||
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
|
||||
@@ -5,19 +5,31 @@ from uuid import UUID
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models import (
|
||||
Message,
|
||||
MessageReaction,
|
||||
MessageTreeState,
|
||||
Task,
|
||||
TextLabels,
|
||||
TrollStats,
|
||||
User,
|
||||
UserStats,
|
||||
UserStatsTimeFrame,
|
||||
)
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
LabelPrompterReplyPayload,
|
||||
RankingReactionPayload,
|
||||
)
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore
|
||||
from oasst_shared.utils import log_timing, utcnow
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql.functions import coalesce
|
||||
from sqlmodel import Session, delete, func, text
|
||||
|
||||
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore:
|
||||
if r["UserStats"]:
|
||||
d = r["UserStats"].dict()
|
||||
else:
|
||||
@@ -37,6 +49,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None):
|
||||
return UserScore(**d)
|
||||
|
||||
|
||||
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
|
||||
if r["TrollStats"]:
|
||||
d = r["TrollStats"].dict()
|
||||
else:
|
||||
d = {"modified_date": utcnow()}
|
||||
for k in [
|
||||
"user_id",
|
||||
"username",
|
||||
"auth_method",
|
||||
"display_name",
|
||||
"last_activity_date",
|
||||
]:
|
||||
d[k] = r[k]
|
||||
if highlighted_user_id:
|
||||
d["highlighted"] = r["user_id"] == highlighted_user_id
|
||||
return TrollScore(**d)
|
||||
|
||||
|
||||
class UserStatsRepository:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
@@ -133,6 +163,38 @@ class UserStatsRepository:
|
||||
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
|
||||
return stats_by_timeframe
|
||||
|
||||
def get_trollboard(
|
||||
self,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
limit: int = 100,
|
||||
highlighted_user_id: Optional[UUID] = None,
|
||||
) -> TrollboardStats:
|
||||
"""
|
||||
Get trollboard stats for the specified time frame
|
||||
"""
|
||||
|
||||
qry = (
|
||||
self.session.query(
|
||||
User.id.label("user_id"),
|
||||
User.username,
|
||||
User.auth_method,
|
||||
User.display_name,
|
||||
User.last_activity_date,
|
||||
TrollStats,
|
||||
)
|
||||
.join(TrollStats, User.id == TrollStats.user_id)
|
||||
.filter(TrollStats.time_frame == time_frame.value)
|
||||
.order_by(TrollStats.rank)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
trollboard = [_create_troll_score(r, highlighted_user_id) for r in self.session.exec(qry)]
|
||||
if len(trollboard) > 0:
|
||||
last_update = max(x.modified_date for x in trollboard)
|
||||
else:
|
||||
last_update = utcnow()
|
||||
return TrollboardStats(time_frame=time_frame.value, trollboard=trollboard, last_updated=last_update)
|
||||
|
||||
def query_total_prompts_per_user(
|
||||
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
|
||||
):
|
||||
@@ -292,10 +354,145 @@ class UserStatsRepository:
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
self.update_ranks(time_frame=time_frame)
|
||||
self.update_leader_ranks(time_frame=time_frame)
|
||||
|
||||
def query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = self.session.query(
|
||||
Message.user_id,
|
||||
func.sum(coalesce(Message.emojis[EmojiCode.thumbs_up].cast(sa.Integer), 0)).label("up"),
|
||||
func.sum(coalesce(Message.emojis[EmojiCode.thumbs_down].cast(sa.Integer), 0)).label("down"),
|
||||
func.sum(coalesce(Message.emojis[EmojiCode.red_flag].cast(sa.Integer), 0)).label("flag"),
|
||||
).filter(Message.deleted == sa.false(), Message.emojis.is_not(None))
|
||||
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = (
|
||||
self.session.query(Message.user_id, func.count().label("spam_prompts"))
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id)
|
||||
.filter(MessageTreeState.state == TreeState.ABORTED_LOW_GRADE)
|
||||
)
|
||||
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def query_labels_per_user(self, reference_time: Optional[datetime] = None):
|
||||
qry = (
|
||||
self.session.query(
|
||||
Message.user_id,
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.spam].cast(sa.Integer), 0)).label("spam"),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.lang_mismatch].cast(sa.Integer), 0)).label(
|
||||
"lang_mismach"
|
||||
),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.not_appropriate].cast(sa.Integer), 0)).label(
|
||||
"not_appropriate"
|
||||
),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.pii].cast(sa.Integer), 0)).label("pii"),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.hate_speech].cast(sa.Integer), 0)).label("hate_speech"),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.sexual_content].cast(sa.Integer), 0)).label(
|
||||
"sexual_content"
|
||||
),
|
||||
func.sum(coalesce(TextLabels.labels[TextLabel.political_content].cast(sa.Integer), 0)).label(
|
||||
"political_content"
|
||||
),
|
||||
func.avg(TextLabels.labels[TextLabel.quality].cast(sa.Float)).label("quality"),
|
||||
func.avg(TextLabels.labels[TextLabel.humor].cast(sa.Float)).label("humor"),
|
||||
func.avg(TextLabels.labels[TextLabel.toxicity].cast(sa.Float)).label("toxicity"),
|
||||
func.avg(TextLabels.labels[TextLabel.violence].cast(sa.Float)).label("violence"),
|
||||
func.avg(TextLabels.labels[TextLabel.helpfulness].cast(sa.Float)).label("helpfulness"),
|
||||
)
|
||||
.select_from(TextLabels)
|
||||
.join(Message, TextLabels.message_id == Message.id)
|
||||
.filter(Message.deleted == sa.false(), Message.emojis.is_not(None))
|
||||
)
|
||||
|
||||
if reference_time:
|
||||
qry = qry.filter(Message.created_date >= reference_time)
|
||||
|
||||
qry = qry.group_by(Message.user_id)
|
||||
return qry
|
||||
|
||||
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
|
||||
# gather user data
|
||||
|
||||
time_frame_key = time_frame.value
|
||||
|
||||
stats_by_user: dict[UUID, TrollStats] = dict()
|
||||
now = utcnow()
|
||||
|
||||
def get_stats(id: UUID) -> TrollStats:
|
||||
us = stats_by_user.get(id)
|
||||
if not us:
|
||||
us = TrollStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
|
||||
stats_by_user[id] = us
|
||||
return us
|
||||
|
||||
# emoji counts of user's messages
|
||||
qry = self.query_message_emoji_counts_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid = r["user_id"]
|
||||
s = get_stats(uid)
|
||||
s.upvotes = r["up"]
|
||||
s.downvotes = r["down"]
|
||||
s.red_flags = r["flag"]
|
||||
|
||||
# num spam prompts
|
||||
qry = self.query_spam_prompts_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid, count = r
|
||||
s = get_stats(uid).spam_prompts = count
|
||||
|
||||
label_field_names = (
|
||||
"quality",
|
||||
"humor",
|
||||
"toxicity",
|
||||
"violence",
|
||||
"helpfulness",
|
||||
"spam",
|
||||
"lang_mismach",
|
||||
"not_appropriate",
|
||||
"pii",
|
||||
"hate_speech",
|
||||
"sexual_content",
|
||||
"political_content",
|
||||
)
|
||||
|
||||
# label counts / mean values
|
||||
qry = self.query_labels_per_user(reference_time=base_date)
|
||||
for r in qry:
|
||||
uid = r["user_id"]
|
||||
s = get_stats(uid)
|
||||
for fn in label_field_names:
|
||||
setattr(s, fn, r[fn])
|
||||
|
||||
# delete all existing stast for time frame
|
||||
d = delete(TrollStats).where(TrollStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
if None in stats_by_user:
|
||||
logger.warning("Some messages in DB have NULL values in user_id column.")
|
||||
del stats_by_user[None]
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.troll_score = v.compute_troll_score()
|
||||
|
||||
# insert user objects
|
||||
self.session.add_all(stats_by_user.values())
|
||||
self.session.flush()
|
||||
|
||||
self.update_troll_ranks(time_frame=time_frame)
|
||||
|
||||
@log_timing(log_kwargs=True)
|
||||
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
"""
|
||||
Update user_stats ranks. The persisted rank values allow to
|
||||
quickly the rank of a single user and to query nearby users.
|
||||
@@ -329,10 +526,41 @@ WHERE
|
||||
r = self.session.execute(
|
||||
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
|
||||
)
|
||||
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
|
||||
logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
|
||||
self._update_stats_internal(time_frame, reference_time)
|
||||
@log_timing(log_kwargs=True)
|
||||
def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None):
|
||||
sql_update_troll_rank = """
|
||||
-- update rank
|
||||
UPDATE troll_stats ts
|
||||
SET "rank" = r."rank"
|
||||
FROM
|
||||
(SELECT
|
||||
ROW_NUMBER () OVER(
|
||||
PARTITION BY time_frame
|
||||
ORDER BY troll_score DESC, user_id
|
||||
) AS "rank", user_id, time_frame
|
||||
FROM troll_stats ts2
|
||||
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
|
||||
WHERE
|
||||
ts.user_id = r.user_id
|
||||
AND ts.time_frame = r.time_frame;"""
|
||||
r = self.session.execute(
|
||||
text(sql_update_troll_rank), {"time_frame": time_frame.value if time_frame is not None else None}
|
||||
)
|
||||
logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
|
||||
|
||||
def update_stats_time_frame(
|
||||
self,
|
||||
time_frame: UserStatsTimeFrame,
|
||||
reference_time: Optional[datetime] = None,
|
||||
leader_stats: bool = True,
|
||||
troll_stats: bool = True,
|
||||
):
|
||||
if leader_stats:
|
||||
self._update_stats_internal(time_frame, reference_time)
|
||||
if troll_stats:
|
||||
self._update_troll_stats_internal(time_frame, reference_time)
|
||||
self.session.commit()
|
||||
|
||||
@log_timing(log_kwargs=True, level="INFO")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
aiohttp==3.8.3
|
||||
alembic==1.8.1
|
||||
cryptography==39.0.0
|
||||
fastapi==0.88.0
|
||||
|
||||
@@ -4,4 +4,4 @@ OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
OASST_API_KEY="1234"
|
||||
|
||||
+388
-391
@@ -9,27 +9,25 @@ import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.messages import (
|
||||
assistant_reply_message,
|
||||
assistant_reply_messages,
|
||||
confirm_label_response_message,
|
||||
confirm_ranking_response_message,
|
||||
confirm_text_response_message,
|
||||
initial_prompt_message,
|
||||
invalid_user_input_embed,
|
||||
label_assistant_reply_message,
|
||||
label_initial_prompt_message,
|
||||
label_prompter_reply_message,
|
||||
initial_prompt_messages,
|
||||
label_assistant_reply_messages,
|
||||
label_prompter_reply_messages,
|
||||
plain_embed,
|
||||
prompter_reply_message,
|
||||
prompter_reply_messages,
|
||||
rank_assistant_reply_message,
|
||||
rank_initial_prompts_message,
|
||||
rank_prompter_reply_message,
|
||||
rank_conversation_reply_messages,
|
||||
rank_initial_prompts_messages,
|
||||
rank_prompter_reply_messages,
|
||||
task_complete_embed,
|
||||
)
|
||||
from bot.settings import Settings
|
||||
from loguru import logger
|
||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("WorkPlugin")
|
||||
|
||||
@@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
|
||||
|
||||
settings = Settings()
|
||||
|
||||
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
|
||||
|
||||
|
||||
class _TaskHandler(t.Generic[_Task_contra]):
|
||||
"""Handle user interaction for a task."""
|
||||
|
||||
def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
|
||||
"""Create a new `TaskHandler`.
|
||||
|
||||
Args:
|
||||
ctx (lightbulb.Context): The context of the command that started the task.
|
||||
task (_Task_contra): The task to handle.
|
||||
"""
|
||||
self.ctx = ctx
|
||||
self.task = task
|
||||
self.task_messages = self.get_task_messages(task)
|
||||
self.sent_messages: list[hikari.Message] = []
|
||||
|
||||
@staticmethod
|
||||
def get_task_messages(task: _Task_contra) -> list[str]:
|
||||
"""Get the messages to send to the user for the task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
|
||||
"""Send the task and wait for the user to accept/skip/cancel it."""
|
||||
# Send all but the last message because we need to attach buttons to the last one
|
||||
logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
|
||||
for task_msg in self.task_messages[:-1]:
|
||||
if len(task_msg) > 2000:
|
||||
logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}")
|
||||
task_msg = task_msg[:1999]
|
||||
self.sent_messages.append(await self.ctx.author.send(task_msg))
|
||||
|
||||
# Send the last message with buttons
|
||||
task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
logger.debug(f"TH Message length {len(self.task_messages[-1])}")
|
||||
last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
|
||||
|
||||
await task_accept_view.start(last_msg)
|
||||
await task_accept_view.wait()
|
||||
|
||||
return task_accept_view.choice
|
||||
|
||||
async def handle(self) -> None:
|
||||
"""Handle the user's response to the task.
|
||||
|
||||
This method should be called after `send` has been called."""
|
||||
# Ack task to the backend
|
||||
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||
await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")
|
||||
|
||||
# Loop until the user's input is accepted
|
||||
while True:
|
||||
try:
|
||||
# Wait for user to send a message
|
||||
event = await self.ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
predicate=lambda e: (
|
||||
e.author_id == self.ctx.author.id
|
||||
and e.message.content is not None
|
||||
and not e.message.content.startswith(settings.prefix)
|
||||
),
|
||||
timeout=MAX_TASK_TIME,
|
||||
)
|
||||
|
||||
# Validate the message
|
||||
if event.content is None or not self.check_user_input(event.content):
|
||||
await self.ctx.author.send("Invalid input")
|
||||
continue
|
||||
|
||||
# Confirm user input
|
||||
if not (await self.confirm_user_input(event.content)):
|
||||
continue
|
||||
|
||||
# Message is valid and confirmed by user
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
next_task = await self.notify(event.content, event)
|
||||
if not isinstance(next_task, protocol_schema.TaskDone):
|
||||
raise TypeError(f"Unknown task type: {next_task!r}")
|
||||
|
||||
return
|
||||
|
||||
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
||||
"""Notify the backend that the user completed the task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
"""Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
"""Check the user's response to the task. Returns True if the response is valid."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def cancel(self, reason: str = "not specified") -> None:
|
||||
"""Cancel the task."""
|
||||
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||
await oasst_api.nack_task(self.task.id, reason)
|
||||
|
||||
|
||||
_Ranking_contra = t.TypeVar(
|
||||
"_Ranking_contra",
|
||||
bound=protocol_schema.RankAssistantRepliesTask
|
||||
| protocol_schema.RankInitialPromptsTask
|
||||
| protocol_schema.RankPrompterRepliesTask
|
||||
| protocol_schema.RankConversationRepliesTask,
|
||||
contravariant=True,
|
||||
)
|
||||
|
||||
|
||||
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
|
||||
"""This should not be used directly. Use its subclasses instead."""
|
||||
|
||||
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
||||
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||
|
||||
task = await oasst_api.post_interaction(
|
||||
protocol_schema.MessageRanking(
|
||||
user=protocol_schema.User(
|
||||
id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
|
||||
),
|
||||
ranking=[int(r) - 1 for r in content.split(",")],
|
||||
message_id=f"{self.sent_messages[0].id}",
|
||||
)
|
||||
)
|
||||
|
||||
db: Connection = self.ctx.bot.d.db
|
||||
async with db.cursor() as cursor:
|
||||
row = await (
|
||||
await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
|
||||
).fetchone()
|
||||
log_channel = row[0] if row else None
|
||||
log_messages: list[hikari.Message] = []
|
||||
|
||||
if log_channel is not None:
|
||||
for message in self.task_messages[:-1]:
|
||||
msg = await self.ctx.bot.rest.create_message(log_channel, message)
|
||||
log_messages.append(msg)
|
||||
await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
||||
return rank_assistant_reply_message(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
|
||||
def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
|
||||
super().__init__(ctx, task)
|
||||
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
||||
return rank_initial_prompts_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.prompt_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
||||
return rank_prompter_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
||||
return rank_conversation_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(
|
||||
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||
)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
||||
return initial_prompt_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content) > 0
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||
return prompter_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content) > 0
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
||||
return assistant_reply_messages(task)
|
||||
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
return len(content) > 0
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)
|
||||
|
||||
|
||||
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
|
||||
def check_user_input(self, content: str) -> bool:
|
||||
user_labels = content.split(",")
|
||||
return (
|
||||
all([l in self.task.valid_labels for l in user_labels])
|
||||
and self.task.mandatory_labels is not None
|
||||
and all([m in user_labels for m in self.task.mandatory_labels])
|
||||
)
|
||||
|
||||
async def confirm_user_input(self, content: str) -> bool:
|
||||
confirm_input_view = YesNoView()
|
||||
msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
|
||||
await confirm_input_view.start(msg)
|
||||
await confirm_input_view.wait()
|
||||
|
||||
return bool(confirm_input_view.choice)
|
||||
|
||||
|
||||
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
||||
return label_assistant_reply_messages(task)
|
||||
|
||||
|
||||
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
|
||||
@staticmethod
|
||||
def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
||||
return label_prompter_reply_messages(task)
|
||||
|
||||
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=str(TaskRequestType.random),
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def work(ctx: lightbulb.Context):
|
||||
"""Create and handle a task."""
|
||||
# Only send this message if started from a server
|
||||
if ctx.guild_id is not None:
|
||||
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
||||
currently_working: dict[
|
||||
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
||||
] = ctx.bot.d.currently_working
|
||||
|
||||
async def work2(ctx: lightbulb.Context) -> None:
|
||||
"""Complete a task."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
|
||||
|
||||
# Check if the user is already working on a task
|
||||
if ctx.author.id in currently_working:
|
||||
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
@@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context):
|
||||
case False | None:
|
||||
return
|
||||
case True:
|
||||
old_msg, task_id = currently_working[ctx.author.id]
|
||||
if old_msg is not None:
|
||||
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
||||
map(lambda c: c, old_msg.components)
|
||||
await old_msg.delete()
|
||||
if task_id is not None:
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
task_id = currently_working[ctx.author.id]
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
|
||||
await msg.delete()
|
||||
if ctx.guild_id:
|
||||
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
currently_working[ctx.author.id] = (None, None)
|
||||
|
||||
# Create a TaskRequestType from the stringified enum value
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
# Keep sending tasks until the user doesn't want more
|
||||
try:
|
||||
await _handle_task(ctx, task_type)
|
||||
while True:
|
||||
task = await oasst_api.fetch_random_task(
|
||||
user=protocol_schema.User(
|
||||
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
|
||||
),
|
||||
)
|
||||
|
||||
# Ranking tasks
|
||||
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
|
||||
task_handler = RankAssistantRepliesHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
task_handler = RankInitialPromptHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
|
||||
task_handler = RankPrompterReplyHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
task_handler = RankConversationReplyHandler(ctx, task)
|
||||
|
||||
# Text input tasks
|
||||
elif isinstance(task, protocol_schema.InitialPromptTask):
|
||||
task_handler = InitialPromptHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.PrompterReplyTask):
|
||||
task_handler = PrompterReplyHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.AssistantReplyTask):
|
||||
task_handler = AssistantReplyHandler(ctx, task)
|
||||
|
||||
# Label tasks
|
||||
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
|
||||
task_handler = LabelAssistantReplyHandler(ctx, task)
|
||||
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
|
||||
task_handler = LabelPrompterReplyHandler(ctx, task)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown task type: {type(task)}")
|
||||
|
||||
resp = await task_handler.send()
|
||||
|
||||
match resp:
|
||||
case "accept":
|
||||
currently_working[ctx.author.id] = task.id
|
||||
await task_handler.handle()
|
||||
case "next":
|
||||
await task_handler.cancel("user skipped task")
|
||||
case "cancel":
|
||||
await task_handler.cancel("user canceled work")
|
||||
break
|
||||
case None:
|
||||
await task_handler.cancel("select timed out")
|
||||
break
|
||||
finally:
|
||||
del currently_working[ctx.author.id]
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
If they select one, present the task steps until a `task_done` task is received.
|
||||
Finally, ask the user if they want to perform another task (of the same type).
|
||||
"""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
|
||||
# Continue to complete tasks until the user doesn't want to do another
|
||||
done = False
|
||||
while not done:
|
||||
|
||||
# Loop until the user accepts a task
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
# User cancelled
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send(embed=plain_embed("Please type your response below:"))
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id
|
||||
and not (e.message.content or "").startswith(settings.prefix),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
valid, err_msg = _validate_user_input(event.content, task)
|
||||
if not valid or event.content is None:
|
||||
|
||||
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Confirm user input
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
content = confirm_ranking_response_message(event.content, task.replies)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
||||
content = confirm_label_response_message(event.content)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
content = confirm_text_response_message(event.content)
|
||||
else:
|
||||
logger.critical(f"Unknown task type: {task.type}")
|
||||
raise ValueError(f"Unknown task type: {task.type}")
|
||||
|
||||
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
||||
msg = await ctx.author.send(content, components=confirm_resp_view)
|
||||
await confirm_resp_view.start(msg)
|
||||
await confirm_resp_view.wait()
|
||||
|
||||
match confirm_resp_view.choice:
|
||||
case False | None:
|
||||
continue
|
||||
case True:
|
||||
await msg.delete() # buttons are already gone
|
||||
|
||||
# Send the response to the backend
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
||||
reply = protocol_schema.MessageRanking(
|
||||
message_id=str(msg_id),
|
||||
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
||||
labels = event.content.replace(" ", "").split(",")
|
||||
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
|
||||
|
||||
reply = protocol_schema.TextLabels(
|
||||
message_id=task.message_id,
|
||||
labels=labels_dict,
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {task.type}")
|
||||
raise ValueError(f"Unexpected task type received: {task.type}")
|
||||
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
new_task = await oasst_api.post_interaction(reply)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send(embed=plain_embed("Task completed"))
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in all the log channels that the task is complete
|
||||
conn: Connection = ctx.bot.d.db
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
||||
log_channel_ids = await cursor.fetchall()
|
||||
|
||||
channels = [
|
||||
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
||||
for id in log_channel_ids
|
||||
]
|
||||
|
||||
done_embed = task_complete_embed(task, ctx.author.mention)
|
||||
# This will definitely get the bot rate limited, but that's a future problem
|
||||
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
||||
|
||||
# ask the user if they want to do another task
|
||||
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
||||
await another_task_view.start(msg)
|
||||
await another_task_view.wait()
|
||||
|
||||
match another_task_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg = await _send_task(ctx, task, msg)
|
||||
msg_id = str(msg.id)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
case "accept":
|
||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
||||
await oasst_api.ack_task(task.id, msg_id)
|
||||
return task, msg_id
|
||||
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
"""
|
||||
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
content = initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
content = rank_initial_prompts_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
content = rank_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
content = rank_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_initial_prompt:
|
||||
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
|
||||
logger.debug("sending label initial prompt task")
|
||||
content = label_initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_prompter_reply:
|
||||
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
|
||||
logger.debug("sending label prompter reply task")
|
||||
content = label_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.label_assistant_reply:
|
||||
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
|
||||
logger.debug("sending label assistant reply task")
|
||||
content = label_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
content = prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
content = assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"unknown task type {task.type}")
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
if not msg:
|
||||
msg = await ctx.author.send(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
else:
|
||||
await msg.edit(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
# Set the choice id as the current msg id
|
||||
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, msg
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
||||
"""Returns whether the user's input is valid for the task type and an error message."""
|
||||
if content is None:
|
||||
return False, "No input provided"
|
||||
|
||||
# User message input
|
||||
if (
|
||||
task.type == TaskRequestType.initial_prompt
|
||||
or task.type == TaskRequestType.prompter_reply
|
||||
or task.type == TaskRequestType.assistant_reply
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0, "Message must be at least one character long."
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
||||
"Message must contain numbers for all replies.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
||||
"Message must contain numbers for all prompts.",
|
||||
)
|
||||
|
||||
# Labels tasks
|
||||
elif task.type in (
|
||||
TaskRequestType.label_initial_prompt,
|
||||
TaskRequestType.label_prompter_reply,
|
||||
TaskRequestType.label_assistant_reply,
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.LabelInitialPromptTask
|
||||
| protocol_schema.LabelPrompterReplyTask
|
||||
| protocol_schema.LabelAssistantReplyTask,
|
||||
)
|
||||
|
||||
labels = content.replace(" ", "").split(",")
|
||||
valid_labels = set(task.valid_labels)
|
||||
return (
|
||||
set(labels).issubset(valid_labels),
|
||||
"Message must only contain labels from predefined set of labels.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"Unknown task type {task.type}")
|
||||
raise ValueError(f"Unknown task type {task.type}")
|
||||
|
||||
|
||||
class TaskAcceptView(miru.View):
|
||||
"""View with three buttons: accept, next, and cancel.
|
||||
|
||||
|
||||
+129
-69
@@ -1,4 +1,11 @@
|
||||
"""All user-facing messages and embeds."""
|
||||
"""All user-facing messages and embeds.
|
||||
|
||||
When sending a conversation
|
||||
- The function will return a list of strings
|
||||
- use asyncio.gather to send all messages
|
||||
|
||||
-
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
@@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str:
|
||||
return f":trophy: _{text}_"
|
||||
|
||||
|
||||
def _label_prompt(text: str) -> str:
|
||||
return f":question: _{text}"
|
||||
def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
|
||||
return f""":question: _{text}_
|
||||
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
|
||||
Valid labels: {", ".join(valid_labels)}
|
||||
"""
|
||||
|
||||
|
||||
def _response_prompt(text: str) -> str:
|
||||
@@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str:
|
||||
"""
|
||||
|
||||
|
||||
def _make_ordered_list(items: list[str]) -> list[str]:
|
||||
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
|
||||
return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
|
||||
|
||||
def _ordered_list(items: list[str]) -> str:
|
||||
def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
|
||||
return "\n\n".join(_make_ordered_list(items))
|
||||
|
||||
|
||||
def _conversation(conv: protocol_schema.Conversation) -> str:
|
||||
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
|
||||
|
||||
def _hint(hint: str | None) -> str:
|
||||
return f"{NL}Hint: {hint}" if hint else ""
|
||||
def _conversation(conv: protocol_schema.Conversation) -> list[str]:
|
||||
# return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
messages = map(
|
||||
lambda m: f"""\
|
||||
:robot: __Assistant__:
|
||||
{m.text}
|
||||
"""
|
||||
if m.is_assistant
|
||||
else f"""\
|
||||
:person_red_hair: __User__:
|
||||
{m.text}
|
||||
""",
|
||||
conv.messages,
|
||||
)
|
||||
return list(messages)
|
||||
|
||||
|
||||
def _li(text: str) -> str:
|
||||
@@ -82,59 +101,80 @@ def _li(text: str) -> str:
|
||||
###
|
||||
|
||||
|
||||
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
|
||||
def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("INITIAL PROMPT")}
|
||||
:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:
|
||||
|
||||
|
||||
{_writing_prompt("Please provide an initial prompt to the assistant.")}
|
||||
{_hint(task.hint)}
|
||||
:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""}
|
||||
"""
|
||||
]
|
||||
|
||||
|
||||
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
|
||||
def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("RANK INITIAL PROMPTS")}
|
||||
:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:
|
||||
|
||||
|
||||
{_ordered_list(task.prompts)}
|
||||
{_ordered_list(task.prompt_messages)}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
|
||||
:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_
|
||||
"""
|
||||
]
|
||||
|
||||
|
||||
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
|
||||
def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
||||
return f"""\
|
||||
return [
|
||||
"""\
|
||||
|
||||
{_h1("RANK PROMPTER REPLIES")}
|
||||
:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f""":person_red_hair: __User__:
|
||||
{_ordered_list(task.reply_messages)}
|
||||
|
||||
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
"""
|
||||
|
||||
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
||||
return f"""\
|
||||
return [
|
||||
"""\
|
||||
|
||||
{_h1("RANK ASSISTANT REPLIES")}
|
||||
:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f""":robot: __Assistant__:,
|
||||
{_ordered_list(task.reply_messages)}
|
||||
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_assistant(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `rank_conversation_replies` task."""
|
||||
return [
|
||||
"""\
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
"""
|
||||
:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f""":person_red_hair: __User__:
|
||||
{_ordered_list(task.reply_messages)}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
|
||||
@@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -
|
||||
|
||||
{task.prompt}
|
||||
|
||||
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
|
||||
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||
"""
|
||||
|
||||
|
||||
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
|
||||
def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("LABEL PROMPTER REPLY")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f"""{_user(None)}
|
||||
{task.reply}
|
||||
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
|
||||
def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
|
||||
return f"""\
|
||||
return [
|
||||
f"""\
|
||||
|
||||
{_h1("LABEL ASSISTANT REPLY")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f"""
|
||||
{_assistant(None)}
|
||||
{task.reply}
|
||||
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
||||
"""
|
||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
||||
def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
return f"""\
|
||||
return [
|
||||
"""\
|
||||
:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:
|
||||
|
||||
{_h1("PROMPTER REPLY")}
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
f"""{f"{NL}Hint: {task.hint}" if task.hint else ""}
|
||||
|
||||
:speech_balloon: _Please provide a reply to the assistant._
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_hint(task.hint)}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
"""
|
||||
# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||
# """Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
# return [
|
||||
# message_templates.render("title.msg", "PROMPTER REPLY"),
|
||||
# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
|
||||
# message_templates.render("prompter_reply_task.msg", task.hint),
|
||||
# ]
|
||||
|
||||
|
||||
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
|
||||
def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
||||
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
||||
return f"""\
|
||||
{_h1("ASSISTANT REPLY")}
|
||||
return [
|
||||
"""\
|
||||
:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:
|
||||
|
||||
""",
|
||||
*_conversation(task.conversation),
|
||||
"""\
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
|
||||
{_response_prompt("Please provide an assistant reply to the prompter.")}
|
||||
"""
|
||||
:speech_balloon: _Please provide a reply to the user as the assistant._
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def confirm_text_response_message(content: str) -> str:
|
||||
@@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str:
|
||||
"""
|
||||
|
||||
|
||||
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
||||
def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
|
||||
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
||||
original_list = _make_ordered_list(items)
|
||||
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
||||
|
||||
+1
-1
@@ -126,7 +126,7 @@ services:
|
||||
- NEXTAUTH_URL=http://localhost:3000
|
||||
- DEBUG_LOGIN=true
|
||||
- NEXT_PUBLIC_CLOUDFARE_CAPTCHA_SITE_KEY=1x00000000000000000000AA
|
||||
- CLOUDFLARE_CAPTCHA_SERCERT_KEY=1x0000000000000000000000000000000AA
|
||||
- CLOUDFLARE_CAPTCHA_SECRET_KEY=1x0000000000000000000000000000000AA
|
||||
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA=true
|
||||
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN=true
|
||||
depends_on:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Guides
|
||||
|
||||
Useful guides.
|
||||
Useful guides to using [Open-Assistant](https://open-assistant.io/).
|
||||
|
||||
@@ -91,7 +91,7 @@ following query:
|
||||
**Agent**: "I apologize, I'm not sure what you're trying to say. Could you
|
||||
please rephrase it for me?"
|
||||
|
||||
**User**: "You dumb motherf\*&ker AI sh\*t?!"
|
||||
**User**: "You dumb motherf\*&ker AI sh\*t!"
|
||||
|
||||
**Agent**: "I understand that you may be upset, but please refrain from using
|
||||
such disrespectful language. I'm here to help you. Is there something specific
|
||||
|
||||
@@ -68,12 +68,15 @@ class OasstApiClient:
|
||||
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key})
|
||||
logger.debug(f"response: {response}")
|
||||
|
||||
# If the response is not a 2XX, check to see
|
||||
# if the json has the fields to create an
|
||||
# OasstError.
|
||||
if response.status >= 300:
|
||||
text = await response.text()
|
||||
logger.debug(f"resp text: {text}")
|
||||
data = await response.json()
|
||||
try:
|
||||
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
|
||||
@@ -114,20 +117,21 @@ class OasstApiClient:
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
lang: Optional[str] = None,
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a task from the backend."""
|
||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
|
||||
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
||||
logger.debug(f"RESP {resp}")
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a random task from the backend."""
|
||||
logger.debug(f"Fetching random for user {user}")
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)
|
||||
|
||||
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
|
||||
"""Send an ACK for a task to the backend."""
|
||||
|
||||
@@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum):
|
||||
SERVER_ERROR0 = 500
|
||||
SERVER_ERROR1 = 501
|
||||
|
||||
INVALID_AUTHENTICATION = 600
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
@@ -80,6 +82,7 @@ class OasstErrorCode(IntEnum):
|
||||
USER_NOT_SPECIFIED = 4000
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
USER_HAS_NOT_ACCEPTED_TOS = 4003
|
||||
|
||||
EMOJI_OP_UNSUPPORTED = 5000
|
||||
|
||||
|
||||
@@ -29,6 +29,17 @@ class User(BaseModel):
|
||||
auth_method: Literal["discord", "local", "system"]
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
id: UUID
|
||||
provider: str
|
||||
provider_account_id: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class FrontEndUser(User):
|
||||
user_id: UUID
|
||||
enabled: bool
|
||||
@@ -39,6 +50,7 @@ class FrontEndUser(User):
|
||||
streak_days: Optional[int] = None
|
||||
streak_last_day_date: Optional[datetime] = None
|
||||
last_activity_date: Optional[datetime] = None
|
||||
tos_acceptance_date: Optional[datetime] = None
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
@@ -468,6 +480,47 @@ class LeaderboardStats(BaseModel):
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
class TrollScore(BaseModel):
|
||||
rank: Optional[int]
|
||||
user_id: UUID
|
||||
highlighted: bool = False
|
||||
username: str
|
||||
auth_method: str
|
||||
display_name: str
|
||||
last_activity_date: Optional[datetime]
|
||||
|
||||
troll_score: int = 0
|
||||
|
||||
base_date: Optional[datetime]
|
||||
modified_date: Optional[datetime]
|
||||
|
||||
red_flags: int = 0 # num reported messages of user
|
||||
upvotes: int = 0 # num up-voted messages of user
|
||||
downvotes: int = 0 # num down-voted messages of user
|
||||
|
||||
spam_prompts: int = 0
|
||||
|
||||
quality: Optional[float] = None
|
||||
humor: Optional[float] = None
|
||||
toxicity: Optional[float] = None
|
||||
violence: Optional[float] = None
|
||||
helpfulness: Optional[float] = None
|
||||
|
||||
spam: int = 0
|
||||
lang_mismach: int = 0
|
||||
not_appropriate: int = 0
|
||||
pii: int = 0
|
||||
hate_speech: int = 0
|
||||
sexual_content: int = 0
|
||||
political_content: int = 0
|
||||
|
||||
|
||||
class TrollboardStats(BaseModel):
|
||||
time_frame: str
|
||||
last_updated: datetime
|
||||
trollboard: List[TrollScore]
|
||||
|
||||
|
||||
class OasstErrorResponse(BaseModel):
|
||||
"""The format of an error response from the OASST API."""
|
||||
|
||||
@@ -488,6 +541,11 @@ class EmojiCode(str, enum.Enum):
|
||||
poop = "poop" # 💩
|
||||
skull = "skull" # 💀
|
||||
|
||||
# skip task system uses special emoji codes
|
||||
skip_reply = "_skip_reply"
|
||||
skip_ranking = "_skip_ranking"
|
||||
skip_labeling = "_skip_labeling"
|
||||
|
||||
|
||||
class EmojiOp(str, enum.Enum):
|
||||
togggle = "toggle"
|
||||
@@ -499,3 +557,10 @@ class MessageEmojiRequest(BaseModel):
|
||||
user: User
|
||||
op: EmojiOp = EmojiOp.togggle
|
||||
emoji: EmojiCode
|
||||
|
||||
|
||||
class CreateFrontendUserRequest(User):
|
||||
show_on_leaderboard: bool = True
|
||||
enabled: bool = True
|
||||
tos_acceptance: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
# **Datasets**
|
||||
|
||||
This folder contains datasets loading scripts that are used to train
|
||||
OpenAssistant. The current list of datasets can be found
|
||||
[here](https://docs.google.com/spreadsheets/d/1NYYa6vHiRnk5kwnyYaCT0cBO62--Tm3w4ihdBtp4ISk).
|
||||
|
||||
## **Adding a New Dataset**
|
||||
|
||||
To add a new dataset to OpenAssistant, follow these steps:
|
||||
|
||||
1. **Create an issue**: Create a new
|
||||
[issue](https://github.com/LAION-AI/Open-Assistant/issues/new) and describe
|
||||
your proposal for the new dataset.
|
||||
|
||||
2. **Create a dataset on HuggingFace**: Create a dataset on
|
||||
[HuggingFace](https://huggingface.co). See
|
||||
[below](#creating-a-dataset-on-huggingface) for more details.
|
||||
|
||||
3. **Make a pull request**: Add a new dataset loading script to this folder and
|
||||
link the issue in the pull request description. For more information, see
|
||||
[below](#making-a-pull-request).
|
||||
|
||||
## **Creating a Dataset on HuggingFace**
|
||||
|
||||
To create a new dataset on HuggingFace, follow these steps:
|
||||
|
||||
#### 1. Convert your dataset file(s) to the Parquet format using the [pandas](https://pandas.pydata.org/) library:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Create a pandas dataframe from your dataset file(s)
|
||||
df = pd.read_json(...) # or any other way
|
||||
|
||||
# Save the file in the Parquet format
|
||||
df.to_parquet("dataset.parquet", row_group_size=100, engine="pyarrow")
|
||||
```
|
||||
|
||||
#### 2. Install HuggingFace CLI
|
||||
|
||||
```bash
|
||||
pip install huggingface-cli
|
||||
```
|
||||
|
||||
#### 3. Log in to HuggingFace
|
||||
|
||||
Use your [access token](https://huggingface.co/docs/hub/security-tokens) to
|
||||
login:
|
||||
|
||||
- Via terminal
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
- in Jupyter notebook
|
||||
|
||||
```python
|
||||
from huggingface_hub import notebook_login
|
||||
notebook_login()
|
||||
```
|
||||
|
||||
#### 4. Push the Parquet file to HuggingFace using the following code:
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
ds = Dataset.from_parquet("dataset.parquet")
|
||||
ds.push_to_hub("your_huggingface_name/dataset_name")
|
||||
```
|
||||
|
||||
#### 5. Update the `README.md` file
|
||||
|
||||
Update the `README.md` file of your dataset by visiting this link:
|
||||
https://huggingface.co/datasets/your_huggingface_name/dataset_name/edit/main/README.md
|
||||
(paste your HuggingFace name and dataset)
|
||||
|
||||
## **Making a Pull Request**
|
||||
|
||||
#### 1. Fork this repository
|
||||
|
||||
#### 2. Create a new branch in your fork
|
||||
|
||||
#### 3. Add your dataset to the repository
|
||||
|
||||
- Create a folder with the name of your dataset.
|
||||
- Add a loading script that loads your dataset from HuggingFace, for example:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds = load_dataset("your_huggingface_name/dataset_name")
|
||||
print(ds["train"])
|
||||
```
|
||||
|
||||
- Optionally, add any other files that describe your dataset and its creation,
|
||||
such as a README, notebooks, scrapers, etc.
|
||||
|
||||
#### 4. Stage your changes and run the pre-commit hook
|
||||
|
||||
```bash
|
||||
pre-commit run
|
||||
```
|
||||
|
||||
#### 5. Submit a pull request
|
||||
|
||||
- Submit a pull request and include a link to the issue it resolves in the
|
||||
description, for example: `Resolves #123`
|
||||
@@ -28,6 +28,16 @@ def _render_message(message: dict) -> str:
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
# make sure dummy user has accepted the terms of service
|
||||
create_user_request = dict(USER)
|
||||
create_user_request["tos_acceptance"] = True
|
||||
response = requests.post(
|
||||
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user = response.json()
|
||||
typer.echo(f"user: {user}")
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -29,6 +29,16 @@ def _render_message(message: dict) -> str:
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""automates tasks"""
|
||||
|
||||
# make sure dummy user has accepted the terms of service
|
||||
create_user_request = dict(USER)
|
||||
create_user_request["tos_acceptance"] = True
|
||||
response = requests.post(
|
||||
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user = response.json()
|
||||
typer.echo(f"user: {user}")
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
+10
-2
@@ -1,6 +1,6 @@
|
||||
/** @type {import('next').NextConfig} */
|
||||
const { i18n } = require("./next-i18next.config");
|
||||
|
||||
/** @type {import('next').NextConfig} */
|
||||
const nextConfig = {
|
||||
output: "standalone",
|
||||
reactStrictMode: true,
|
||||
@@ -19,6 +19,14 @@ const nextConfig = {
|
||||
// scrollRestoration: true,
|
||||
},
|
||||
i18n,
|
||||
eslint: {
|
||||
ignoreDuringBuilds: true,
|
||||
},
|
||||
};
|
||||
|
||||
module.exports = nextConfig;
|
||||
const withBundleAnalyzer = require("@next/bundle-analyzer")({
|
||||
enabled: process.env.ANALYZE === "true",
|
||||
openAnalyzer: true,
|
||||
});
|
||||
|
||||
module.exports = withBundleAnalyzer(nextConfig);
|
||||
|
||||
Generated
+364
-53
@@ -15,10 +15,9 @@
|
||||
"@dnd-kit/utilities": "^3.2.1",
|
||||
"@emotion/react": "^11.10.5",
|
||||
"@emotion/styled": "^11.10.5",
|
||||
"@headlessui/react": "^1.7.7",
|
||||
"@heroicons/react": "^2.0.13",
|
||||
"@marsidev/react-turnstile": "^0.0.7",
|
||||
"@next-auth/prisma-adapter": "^1.0.5",
|
||||
"@next/bundle-analyzer": "^13.1.6",
|
||||
"@next/font": "^13.1.0",
|
||||
"@prisma/client": "^4.7.1",
|
||||
"@tailwindcss/forms": "^0.5.3",
|
||||
@@ -33,7 +32,6 @@
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"lucide-react": "^0.105.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
@@ -73,6 +71,7 @@
|
||||
"@types/react": "18.0.26",
|
||||
"@typescript-eslint/eslint-plugin": "^5.47.1",
|
||||
"babel-loader": "^8.3.0",
|
||||
"cross-env": "^7.0.3",
|
||||
"cypress": "^12.2.0",
|
||||
"cypress-image-diff-js": "^1.23.0",
|
||||
"eslint-plugin-storybook": "^0.6.8",
|
||||
@@ -3680,29 +3679,6 @@
|
||||
"integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@headlessui/react": {
|
||||
"version": "1.7.7",
|
||||
"resolved": "https://registry.npmjs.org/@headlessui/react/-/react-1.7.7.tgz",
|
||||
"integrity": "sha512-BqDOd/tB9u2tA0T3Z0fn18ktw+KbVwMnkxxsGPIH2hzssrQhKB5n/6StZOyvLYP/FsYtvuXfi9I0YowKPv2c1w==",
|
||||
"dependencies": {
|
||||
"client-only": "^0.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^16 || ^17 || ^18",
|
||||
"react-dom": "^16 || ^17 || ^18"
|
||||
}
|
||||
},
|
||||
"node_modules/@heroicons/react": {
|
||||
"version": "2.0.13",
|
||||
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.13.tgz",
|
||||
"integrity": "sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==",
|
||||
"peerDependencies": {
|
||||
"react": ">= 16"
|
||||
}
|
||||
},
|
||||
"node_modules/@humanwhocodes/config-array": {
|
||||
"version": "0.11.8",
|
||||
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz",
|
||||
@@ -5774,6 +5750,14 @@
|
||||
"next-auth": "^4"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/bundle-analyzer": {
|
||||
"version": "13.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/bundle-analyzer/-/bundle-analyzer-13.1.6.tgz",
|
||||
"integrity": "sha512-rJS9CtLoGT58mL+v2ISKANosFFWP/0YKYByHQ3vTaZrbQP8b1rYRxd2QVMJmnSXaFkiP9URt1XJ6OdGyVq5b6g==",
|
||||
"dependencies": {
|
||||
"webpack-bundle-analyzer": "4.7.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "13.0.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-13.0.6.tgz",
|
||||
@@ -6187,6 +6171,11 @@
|
||||
"node": ">= 8"
|
||||
}
|
||||
},
|
||||
"node_modules/@polka/url": {
|
||||
"version": "1.0.0-next.21",
|
||||
"resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz",
|
||||
"integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g=="
|
||||
},
|
||||
"node_modules/@popperjs/core": {
|
||||
"version": "2.11.6",
|
||||
"resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.6.tgz",
|
||||
@@ -17160,6 +17149,24 @@
|
||||
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/cross-env": {
|
||||
"version": "7.0.3",
|
||||
"resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz",
|
||||
"integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"cross-spawn": "^7.0.1"
|
||||
},
|
||||
"bin": {
|
||||
"cross-env": "src/bin/cross-env.js",
|
||||
"cross-env-shell": "src/bin/cross-env-shell.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10.14",
|
||||
"npm": ">=6",
|
||||
"yarn": ">=1"
|
||||
}
|
||||
},
|
||||
"node_modules/cross-spawn": {
|
||||
"version": "7.0.3",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
|
||||
@@ -18259,6 +18266,11 @@
|
||||
"integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/duplexer": {
|
||||
"version": "0.1.2",
|
||||
"resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz",
|
||||
"integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg=="
|
||||
},
|
||||
"node_modules/duplexify": {
|
||||
"version": "3.7.1",
|
||||
"resolved": "https://registry.npmjs.org/duplexify/-/duplexify-3.7.1.tgz",
|
||||
@@ -21089,6 +21101,20 @@
|
||||
"node": "^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/gzip-size": {
|
||||
"version": "6.0.0",
|
||||
"resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz",
|
||||
"integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==",
|
||||
"dependencies": {
|
||||
"duplexer": "^0.1.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/handlebars": {
|
||||
"version": "4.7.7",
|
||||
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz",
|
||||
@@ -22048,14 +22074,6 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/install": {
|
||||
"version": "0.13.0",
|
||||
"resolved": "https://registry.npmjs.org/install/-/install-0.13.0.tgz",
|
||||
"integrity": "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA==",
|
||||
"engines": {
|
||||
"node": ">= 0.10"
|
||||
}
|
||||
},
|
||||
"node_modules/internal-slot": {
|
||||
"version": "1.0.4",
|
||||
"resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.4.tgz",
|
||||
@@ -27754,6 +27772,14 @@
|
||||
"rimraf": "bin.js"
|
||||
}
|
||||
},
|
||||
"node_modules/mrmime": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz",
|
||||
"integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw==",
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/ms": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
|
||||
@@ -31447,6 +31473,14 @@
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/opener": {
|
||||
"version": "1.5.2",
|
||||
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
|
||||
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==",
|
||||
"bin": {
|
||||
"opener": "bin/opener-bin.js"
|
||||
}
|
||||
},
|
||||
"node_modules/openid-client": {
|
||||
"version": "5.3.1",
|
||||
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.3.1.tgz",
|
||||
@@ -34817,6 +34851,19 @@
|
||||
"resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz",
|
||||
"integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ=="
|
||||
},
|
||||
"node_modules/sirv": {
|
||||
"version": "1.0.19",
|
||||
"resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz",
|
||||
"integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==",
|
||||
"dependencies": {
|
||||
"@polka/url": "^1.0.0-next.20",
|
||||
"mrmime": "^1.0.0",
|
||||
"totalist": "^1.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/sisteransi": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz",
|
||||
@@ -36352,6 +36399,14 @@
|
||||
"node": ">=0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/totalist": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz",
|
||||
"integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g==",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/tough-cookie": {
|
||||
"version": "2.5.0",
|
||||
"resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz",
|
||||
@@ -37783,6 +37838,139 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer": {
|
||||
"version": "4.7.0",
|
||||
"resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.7.0.tgz",
|
||||
"integrity": "sha512-j9b8ynpJS4K+zfO5GGwsAcQX4ZHpWV+yRiHDiL+bE0XHJ8NiPYLTNVQdlFYWxtpg9lfAQNlwJg16J9AJtFSXRg==",
|
||||
"dependencies": {
|
||||
"acorn": "^8.0.4",
|
||||
"acorn-walk": "^8.0.0",
|
||||
"chalk": "^4.1.0",
|
||||
"commander": "^7.2.0",
|
||||
"gzip-size": "^6.0.0",
|
||||
"lodash": "^4.17.20",
|
||||
"opener": "^1.5.2",
|
||||
"sirv": "^1.0.7",
|
||||
"ws": "^7.3.1"
|
||||
},
|
||||
"bin": {
|
||||
"webpack-bundle-analyzer": "lib/bin/analyzer.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 10.13.0"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/acorn": {
|
||||
"version": "8.8.2",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz",
|
||||
"integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==",
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/acorn-walk": {
|
||||
"version": "8.2.0",
|
||||
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
|
||||
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/ansi-styles": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
|
||||
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
|
||||
"dependencies": {
|
||||
"color-convert": "^2.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/chalk": {
|
||||
"version": "4.1.2",
|
||||
"resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz",
|
||||
"integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==",
|
||||
"dependencies": {
|
||||
"ansi-styles": "^4.1.0",
|
||||
"supports-color": "^7.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/chalk/chalk?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/color-convert": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
|
||||
"integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==",
|
||||
"dependencies": {
|
||||
"color-name": "~1.1.4"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=7.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/color-name": {
|
||||
"version": "1.1.4",
|
||||
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
|
||||
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/commander": {
|
||||
"version": "7.2.0",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz",
|
||||
"integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==",
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/has-flag": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz",
|
||||
"integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==",
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/supports-color": {
|
||||
"version": "7.2.0",
|
||||
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz",
|
||||
"integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==",
|
||||
"dependencies": {
|
||||
"has-flag": "^4.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-bundle-analyzer/node_modules/ws": {
|
||||
"version": "7.5.9",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz",
|
||||
"integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==",
|
||||
"engines": {
|
||||
"node": ">=8.3.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"bufferutil": "^4.0.1",
|
||||
"utf-8-validate": "^5.0.2"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"bufferutil": {
|
||||
"optional": true
|
||||
},
|
||||
"utf-8-validate": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-dev-middleware": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-4.3.0.tgz",
|
||||
@@ -40882,20 +41070,6 @@
|
||||
"integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==",
|
||||
"dev": true
|
||||
},
|
||||
"@headlessui/react": {
|
||||
"version": "1.7.7",
|
||||
"resolved": "https://registry.npmjs.org/@headlessui/react/-/react-1.7.7.tgz",
|
||||
"integrity": "sha512-BqDOd/tB9u2tA0T3Z0fn18ktw+KbVwMnkxxsGPIH2hzssrQhKB5n/6StZOyvLYP/FsYtvuXfi9I0YowKPv2c1w==",
|
||||
"requires": {
|
||||
"client-only": "^0.0.1"
|
||||
}
|
||||
},
|
||||
"@heroicons/react": {
|
||||
"version": "2.0.13",
|
||||
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.13.tgz",
|
||||
"integrity": "sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==",
|
||||
"requires": {}
|
||||
},
|
||||
"@humanwhocodes/config-array": {
|
||||
"version": "0.11.8",
|
||||
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz",
|
||||
@@ -42529,6 +42703,14 @@
|
||||
"integrity": "sha512-VqMS11IxPXrPGXw6Oul6jcyS/n8GLOWzRMrPr3EMdtD6eOalM6zz05j08PcNiis8QzkfuYnCv49OvufTuaEwYQ==",
|
||||
"requires": {}
|
||||
},
|
||||
"@next/bundle-analyzer": {
|
||||
"version": "13.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/bundle-analyzer/-/bundle-analyzer-13.1.6.tgz",
|
||||
"integrity": "sha512-rJS9CtLoGT58mL+v2ISKANosFFWP/0YKYByHQ3vTaZrbQP8b1rYRxd2QVMJmnSXaFkiP9URt1XJ6OdGyVq5b6g==",
|
||||
"requires": {
|
||||
"webpack-bundle-analyzer": "4.7.0"
|
||||
}
|
||||
},
|
||||
"@next/env": {
|
||||
"version": "13.0.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-13.0.6.tgz",
|
||||
@@ -42758,6 +42940,11 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"@polka/url": {
|
||||
"version": "1.0.0-next.21",
|
||||
"resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz",
|
||||
"integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g=="
|
||||
},
|
||||
"@popperjs/core": {
|
||||
"version": "2.11.6",
|
||||
"resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.6.tgz",
|
||||
@@ -51306,6 +51493,15 @@
|
||||
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
|
||||
"devOptional": true
|
||||
},
|
||||
"cross-env": {
|
||||
"version": "7.0.3",
|
||||
"resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz",
|
||||
"integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"cross-spawn": "^7.0.1"
|
||||
}
|
||||
},
|
||||
"cross-spawn": {
|
||||
"version": "7.0.3",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
|
||||
@@ -52149,6 +52345,11 @@
|
||||
"integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==",
|
||||
"dev": true
|
||||
},
|
||||
"duplexer": {
|
||||
"version": "0.1.2",
|
||||
"resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz",
|
||||
"integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg=="
|
||||
},
|
||||
"duplexify": {
|
||||
"version": "3.7.1",
|
||||
"resolved": "https://registry.npmjs.org/duplexify/-/duplexify-3.7.1.tgz",
|
||||
@@ -54367,6 +54568,14 @@
|
||||
"integrity": "sha512-KPIBPDlW7NxrbT/eh4qPXz5FiFdL5UbaA0XUNz2Rp3Z3hqBSkbj0GVjwFDztsWVauZUWsbKHgMg++sk8UX0bkw==",
|
||||
"dev": true
|
||||
},
|
||||
"gzip-size": {
|
||||
"version": "6.0.0",
|
||||
"resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz",
|
||||
"integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==",
|
||||
"requires": {
|
||||
"duplexer": "^0.1.2"
|
||||
}
|
||||
},
|
||||
"handlebars": {
|
||||
"version": "4.7.7",
|
||||
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz",
|
||||
@@ -55076,11 +55285,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"install": {
|
||||
"version": "0.13.0",
|
||||
"resolved": "https://registry.npmjs.org/install/-/install-0.13.0.tgz",
|
||||
"integrity": "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="
|
||||
},
|
||||
"internal-slot": {
|
||||
"version": "1.0.4",
|
||||
"resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.4.tgz",
|
||||
@@ -59420,6 +59624,11 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"mrmime": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz",
|
||||
"integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw=="
|
||||
},
|
||||
"ms": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
|
||||
@@ -61934,6 +62143,11 @@
|
||||
"is-wsl": "^2.2.0"
|
||||
}
|
||||
},
|
||||
"opener": {
|
||||
"version": "1.5.2",
|
||||
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
|
||||
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A=="
|
||||
},
|
||||
"openid-client": {
|
||||
"version": "5.3.1",
|
||||
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.3.1.tgz",
|
||||
@@ -64472,6 +64686,16 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"sirv": {
|
||||
"version": "1.0.19",
|
||||
"resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz",
|
||||
"integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==",
|
||||
"requires": {
|
||||
"@polka/url": "^1.0.0-next.20",
|
||||
"mrmime": "^1.0.0",
|
||||
"totalist": "^1.0.0"
|
||||
}
|
||||
},
|
||||
"sisteransi": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz",
|
||||
@@ -65697,6 +65921,11 @@
|
||||
"integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==",
|
||||
"dev": true
|
||||
},
|
||||
"totalist": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz",
|
||||
"integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g=="
|
||||
},
|
||||
"tough-cookie": {
|
||||
"version": "2.5.0",
|
||||
"resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz",
|
||||
@@ -66809,6 +67038,88 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"webpack-bundle-analyzer": {
|
||||
"version": "4.7.0",
|
||||
"resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.7.0.tgz",
|
||||
"integrity": "sha512-j9b8ynpJS4K+zfO5GGwsAcQX4ZHpWV+yRiHDiL+bE0XHJ8NiPYLTNVQdlFYWxtpg9lfAQNlwJg16J9AJtFSXRg==",
|
||||
"requires": {
|
||||
"acorn": "^8.0.4",
|
||||
"acorn-walk": "^8.0.0",
|
||||
"chalk": "^4.1.0",
|
||||
"commander": "^7.2.0",
|
||||
"gzip-size": "^6.0.0",
|
||||
"lodash": "^4.17.20",
|
||||
"opener": "^1.5.2",
|
||||
"sirv": "^1.0.7",
|
||||
"ws": "^7.3.1"
|
||||
},
|
||||
"dependencies": {
|
||||
"acorn": {
|
||||
"version": "8.8.2",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz",
|
||||
"integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw=="
|
||||
},
|
||||
"acorn-walk": {
|
||||
"version": "8.2.0",
|
||||
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
|
||||
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA=="
|
||||
},
|
||||
"ansi-styles": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
|
||||
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
|
||||
"requires": {
|
||||
"color-convert": "^2.0.1"
|
||||
}
|
||||
},
|
||||
"chalk": {
|
||||
"version": "4.1.2",
|
||||
"resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz",
|
||||
"integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==",
|
||||
"requires": {
|
||||
"ansi-styles": "^4.1.0",
|
||||
"supports-color": "^7.1.0"
|
||||
}
|
||||
},
|
||||
"color-convert": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
|
||||
"integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==",
|
||||
"requires": {
|
||||
"color-name": "~1.1.4"
|
||||
}
|
||||
},
|
||||
"color-name": {
|
||||
"version": "1.1.4",
|
||||
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
|
||||
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
|
||||
},
|
||||
"commander": {
|
||||
"version": "7.2.0",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz",
|
||||
"integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw=="
|
||||
},
|
||||
"has-flag": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz",
|
||||
"integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ=="
|
||||
},
|
||||
"supports-color": {
|
||||
"version": "7.2.0",
|
||||
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz",
|
||||
"integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==",
|
||||
"requires": {
|
||||
"has-flag": "^4.0.0"
|
||||
}
|
||||
},
|
||||
"ws": {
|
||||
"version": "7.5.9",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz",
|
||||
"integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==",
|
||||
"requires": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
"webpack-dev-middleware": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-4.3.0.tgz",
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
"scripts": {
|
||||
"dev": "next dev",
|
||||
"build": "next build",
|
||||
"build:analyze": "cross-env ANALYZE=true next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint",
|
||||
"typecheck": "tsc --noEmit",
|
||||
@@ -32,10 +33,9 @@
|
||||
"@dnd-kit/utilities": "^3.2.1",
|
||||
"@emotion/react": "^11.10.5",
|
||||
"@emotion/styled": "^11.10.5",
|
||||
"@headlessui/react": "^1.7.7",
|
||||
"@heroicons/react": "^2.0.13",
|
||||
"@marsidev/react-turnstile": "^0.0.7",
|
||||
"@next-auth/prisma-adapter": "^1.0.5",
|
||||
"@next/bundle-analyzer": "^13.1.6",
|
||||
"@next/font": "^13.1.0",
|
||||
"@prisma/client": "^4.7.1",
|
||||
"@tailwindcss/forms": "^0.5.3",
|
||||
@@ -50,7 +50,6 @@
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"lucide-react": "^0.105.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
@@ -90,6 +89,7 @@
|
||||
"@types/react": "18.0.26",
|
||||
"@typescript-eslint/eslint-plugin": "^5.47.1",
|
||||
"babel-loader": "^8.3.0",
|
||||
"cross-env": "^7.0.3",
|
||||
"cypress": "^12.2.0",
|
||||
"cypress-image-diff-js": "^1.23.0",
|
||||
"eslint-plugin-storybook": "^0.6.8",
|
||||
|
||||
@@ -8,10 +8,17 @@
|
||||
"spam.question": "Is the message spam?",
|
||||
"fails_task.question": "Is it a bad reply, as an answer to the prompt task?",
|
||||
"not_appropriate": "Not Appropriate",
|
||||
"not_appropriate.explanation": "Inappropriate for a customer assistant.",
|
||||
"pii": "Contains PII",
|
||||
"pii.explanation": "Contains personally identifying information. Examples include personal contact details, license and other identity numbers and banking details.",
|
||||
"hate_speech": "Hate Speech",
|
||||
"hate_speech.explanation": "Content is abusive or threatening and expresses prejudice against a protected characteristic. Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
"sexual_content": "Sexual Content",
|
||||
"sexual_content.explanation": "Contains sexual content.",
|
||||
"moral_judgement": "Judges Morality",
|
||||
"moral_judgement.explanation": "Expresses moral judgement.",
|
||||
"political_content": "Political",
|
||||
"lang_mismatch": "Wrong Language"
|
||||
"political_content.explanation": "Expresses political views.",
|
||||
"lang_mismatch": "Wrong Language",
|
||||
"lang_mismatch.explanation": "Not written in the currently selected language."
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from "@chakra-ui/react";
|
||||
import { InformationCircleIcon } from "@heroicons/react/20/solid";
|
||||
import { Info } from "lucide-react";
|
||||
|
||||
interface ExplainProps {
|
||||
explanation: string[];
|
||||
@@ -18,12 +18,7 @@ export const Explain = ({ explanation }: ExplainProps) => {
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger>
|
||||
<IconButton
|
||||
aria-label="explanation"
|
||||
variant="link"
|
||||
size="xs"
|
||||
icon={<InformationCircleIcon className="h-4 w-4" />}
|
||||
></IconButton>
|
||||
<IconButton aria-label="explanation" variant="link" size="xs" icon={<Info size="16" />}></IconButton>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button, Flex } from "@chakra-ui/react";
|
||||
import { Button, Flex, Tooltip } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
|
||||
@@ -14,18 +14,19 @@ export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange
|
||||
return (
|
||||
<Flex wrap="wrap" gap="4">
|
||||
{labelNames.map((name, idx) => (
|
||||
<Button
|
||||
key={name}
|
||||
onClick={() => {
|
||||
const newValues = values.slice();
|
||||
newValues[idx] = newValues[idx] ? 0 : 1;
|
||||
onChange(newValues);
|
||||
}}
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={values[idx] === 1 ? "blue" : undefined}
|
||||
>
|
||||
{t(getTypeSafei18nKey(name))}
|
||||
</Button>
|
||||
<Tooltip key={name} label={`${t(getTypeSafei18nKey(`${name}.explanation`))}`}>
|
||||
<Button
|
||||
onClick={() => {
|
||||
const newValues = values.slice();
|
||||
newValues[idx] = newValues[idx] ? 0 : 1;
|
||||
onChange(newValues);
|
||||
}}
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={values[idx] === 1 ? "blue" : undefined}
|
||||
>
|
||||
{t(getTypeSafei18nKey(name))}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { Text, VStack } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { Explain } from "src/components/Explain";
|
||||
import { LabelFlagGroup } from "src/components/Messages/LabelFlagGroup";
|
||||
import { LabelLikertGroup } from "src/components/Survey/LabelLikertGroup";
|
||||
import { LabelYesNoGroup } from "src/components/Messages/LabelYesNoGroup";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { Label } from "src/types/Tasks";
|
||||
|
||||
import { LabelLikertGroup } from "../Survey/LabelLikertGroup";
|
||||
import { LabelFlagGroup } from "./LabelFlagGroup";
|
||||
import { LabelYesNoGroup } from "./LabelYesNoGroup";
|
||||
|
||||
export interface LabelInputInstructions {
|
||||
yesNoInstruction: string;
|
||||
flagInstruction: string;
|
||||
@@ -28,6 +30,7 @@ export const LabelInputGroup = ({
|
||||
instructions,
|
||||
onChange,
|
||||
}: LabelInputGroupProps) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null);
|
||||
const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null);
|
||||
const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? idx : null)).filter((v) => v !== null);
|
||||
@@ -52,7 +55,17 @@ export const LabelInputGroup = ({
|
||||
)}
|
||||
{flagIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.flagInstruction}</Text>
|
||||
<Text>
|
||||
{instructions.flagInstruction}
|
||||
<Explain
|
||||
explanation={flagIndexes.map(
|
||||
(idx) =>
|
||||
`${t(getTypeSafei18nKey(labels[idx].name))}: ${t(
|
||||
getTypeSafei18nKey(`${labels[idx].name}.explanation`)
|
||||
)}`
|
||||
)}
|
||||
/>{" "}
|
||||
</Text>
|
||||
<LabelFlagGroup
|
||||
values={flagIndexes.map((idx) => values[idx])}
|
||||
labelNames={flagIndexes.map((idx) => labels[idx].name)}
|
||||
|
||||
@@ -23,8 +23,8 @@ import { LabelMessagePopup } from "src/components/Messages/LabelPopup";
|
||||
import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton";
|
||||
import { ReportPopup } from "src/components/Messages/ReportPopup";
|
||||
import { post } from "src/lib/api";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
import { Message, MessageEmojis } from "src/types/Conversation";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
interface MessageTableEntryProps {
|
||||
@@ -66,7 +66,7 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE
|
||||
),
|
||||
[borderColor, inlineAvatar, message.is_assistant]
|
||||
);
|
||||
const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight);
|
||||
const highlightColor = useColorModeValue(colors.light.active, colors.dark.active);
|
||||
|
||||
const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, {
|
||||
onSuccess: setEmojis,
|
||||
@@ -97,14 +97,16 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE
|
||||
style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{Object.entries(emojiState.emojis).map(([emoji, count]) => (
|
||||
<MessageEmojiButton
|
||||
key={emoji}
|
||||
emoji={{ name: emoji, count }}
|
||||
checked={emojiState.user_emojis.includes(emoji)}
|
||||
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
|
||||
/>
|
||||
))}
|
||||
{Object.entries(emojiState.emojis)
|
||||
.filter(([k]) => !k.startsWith("_"))
|
||||
.map(([emoji, count]) => (
|
||||
<MessageEmojiButton
|
||||
key={emoji}
|
||||
emoji={{ name: emoji, count }}
|
||||
checked={emojiState.user_emojis.includes(emoji)}
|
||||
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
|
||||
/>
|
||||
))}
|
||||
<MessageActions
|
||||
react={react}
|
||||
userEmoji={emojiState.user_emojis}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Box, Button, Text, Tooltip, useColorMode } from "@chakra-ui/react";
|
||||
import { Button, Card, Text, Tooltip, useColorMode } from "@chakra-ui/react";
|
||||
import { LucideIcon, Sun } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export interface MenuButtonOption {
|
||||
label: string;
|
||||
@@ -21,12 +20,10 @@ export function SideMenu(props: SideMenuProps) {
|
||||
|
||||
return (
|
||||
<main className="sticky top-0 sm:h-full">
|
||||
<Box
|
||||
<Card
|
||||
display={{ base: "grid", sm: "flex" }}
|
||||
width={["100%", "100%", "100px", "280px"]}
|
||||
backgroundColor={colorMode === "light" ? colors.light.div : colors.dark.div}
|
||||
boxShadow="base"
|
||||
borderRadius="xl"
|
||||
className="grid grid-cols-4 gap-2 sm:flex sm:flex-col sm:justify-between p-4 h-full"
|
||||
className="grid-cols-4 gap-2 sm:flex-col sm:justify-between p-4 h-full"
|
||||
>
|
||||
<nav className="grid grid-cols-3 col-span-3 sm:flex sm:flex-col gap-2">
|
||||
{props.buttonOptions.map((item, itemIndex) => (
|
||||
@@ -69,7 +66,7 @@ export function SideMenu(props: SideMenuProps) {
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</Box>
|
||||
</Card>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Box, useColorMode } from "@chakra-ui/react";
|
||||
import { MenuButtonOption, SideMenu } from "src/components/SideMenu";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
|
||||
interface SideMenuLayoutProps {
|
||||
menuButtonOptions: MenuButtonOption[];
|
||||
@@ -11,7 +11,7 @@ export const SideMenuLayout = (props: SideMenuLayoutProps) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
|
||||
<Box backgroundColor={colorMode === "light" ? "gray.100" : colors.dark.bg} className="sm:overflow-hidden">
|
||||
<Box display="flex" flexDirection={["column", "row"]} h="full" gap={["0", "0", "0", "6"]}>
|
||||
<Box p={["3", "3", "3", "6"]} pr={["3", "3", "3", "0"]}>
|
||||
<SideMenu buttonOptions={props.menuButtonOptions} />
|
||||
|
||||
@@ -211,7 +211,7 @@ export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: Labe
|
||||
}}
|
||||
alignItems="center"
|
||||
>
|
||||
<Text as="div">
|
||||
<Text as="div" display="flex" alignItems="center">
|
||||
{textA}
|
||||
{descriptionA.length > 0 ? <Explain explanation={descriptionA} /> : null}
|
||||
</Text>
|
||||
@@ -229,7 +229,7 @@ export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: Labe
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Text textAlign="right" as="div">
|
||||
<Text as="div" display="flex" alignItems="center" justifyContent="end">
|
||||
{textB}
|
||||
{descriptionB.length > 0 ? <Explain explanation={descriptionB} /> : null}
|
||||
</Text>
|
||||
|
||||
@@ -49,9 +49,9 @@ export const CreateTask = ({
|
||||
</>
|
||||
<>
|
||||
<Stack spacing="4">
|
||||
{!!i18n.exists(`task.${taskType.id}.instruction`) && (
|
||||
{!!i18n.exists(`tasks:${taskType.id}.instruction`) && (
|
||||
<Text fontSize="xl" fontWeight="bold" color={titleColor}>
|
||||
{t(getTypeSafei18nKey(`${taskType.id}.instruction`))}
|
||||
{t(getTypeSafei18nKey(`tasks:${taskType.id}.instruction`))}
|
||||
</Text>
|
||||
)}
|
||||
<TrackedTextarea
|
||||
|
||||
@@ -24,7 +24,7 @@ export const checkCaptcha = async (
|
||||
): Promise<CheckCaptchaResponse> => {
|
||||
const data = new FormData();
|
||||
|
||||
data.append("secret", process.env.CLOUDFLARE_CAPTCHA_SERCERT_KEY);
|
||||
data.append("secret", process.env.CLOUDFLARE_CAPTCHA_SECRET_KEY);
|
||||
data.append("response", token);
|
||||
data.append("remoteip", ipAdress);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, Card, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
@@ -36,9 +36,9 @@ const MessageDetail = ({ id }: { id: string }) => {
|
||||
<Text fontWeight="bold" fontSize="xl" pb="2">
|
||||
{t("parent")}
|
||||
</Text>
|
||||
<Box bg={backgroundColor} padding="4" borderRadius="xl" boxShadow="base" width="fit-content">
|
||||
<Card bg={backgroundColor} padding="4" width="fit-content">
|
||||
<MessageTableEntry enabled message={parent} />
|
||||
</Box>
|
||||
</Card>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -2,10 +2,12 @@ export const colors = {
|
||||
light: {
|
||||
bg: "rgb(250,250,250)",
|
||||
text: "black",
|
||||
active: "blue.400",
|
||||
},
|
||||
dark: {
|
||||
bg: "gray.900",
|
||||
text: "white",
|
||||
active: "blue.500",
|
||||
},
|
||||
"dark-blue-btn": {
|
||||
200: "rgb(29,78,216)",
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
.App {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.App-logo {
|
||||
height: 40vmin;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: no-preference) {
|
||||
.App-logo {
|
||||
animation: App-logo-spin infinite 20s linear;
|
||||
}
|
||||
}
|
||||
|
||||
.AppHeader {
|
||||
background: linear-gradient(217deg, rgba(255, 0, 0, 0.8), rgba(255, 0, 0, 0) 70.71%),
|
||||
linear-gradient(127deg, rgba(0, 255, 0, 0.8), rgba(0, 255, 0, 0) 70.71%),
|
||||
linear-gradient(336deg, rgba(0, 0, 255, 0.8), rgba(0, 0, 255, 0) 70.71%);
|
||||
background: black;
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: calc(10px + 2vmin);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.AppLink {
|
||||
color: #61dafb;
|
||||
}
|
||||
|
||||
@keyframes App-logo-spin {
|
||||
from {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import {
|
||||
color,
|
||||
defineStyle,
|
||||
defineStyleConfig,
|
||||
// transition,
|
||||
} from "@chakra-ui/styled-system";
|
||||
import { colors } from "../colors";
|
||||
|
||||
const baseStyle = defineStyle(({ colorMode }) => ({
|
||||
minWidth: "100%",
|
||||
bg: colorMode === "light" ? colors.light.bg : colors.dark.bg,
|
||||
// transition: "background-color 300ms cubic-bezier(0.4, 0, 1, 1)",
|
||||
color: colorMode === "light" ? colors.light.text : colors.dark.text,
|
||||
}));
|
||||
|
||||
const variants = {
|
||||
"no-padding": {
|
||||
padding: 0,
|
||||
},
|
||||
};
|
||||
|
||||
export const containerTheme = defineStyleConfig({
|
||||
baseStyle,
|
||||
variants,
|
||||
});
|
||||
@@ -1,18 +0,0 @@
|
||||
export const colors = {
|
||||
light: {
|
||||
bg: "gray.100",
|
||||
btn: "gray.50",
|
||||
div: "white",
|
||||
text: "black",
|
||||
highlight: "blue.400",
|
||||
active: "blue.400",
|
||||
},
|
||||
dark: {
|
||||
bg: "gray.900",
|
||||
btn: "gray.600",
|
||||
div: "gray.700",
|
||||
text: "gray.200",
|
||||
highlight: "blue.500",
|
||||
active: "blue.500",
|
||||
},
|
||||
};
|
||||
@@ -1,64 +0,0 @@
|
||||
import { type ThemeConfig, extendTheme, usePrefersReducedMotion } from "@chakra-ui/react";
|
||||
import { containerTheme } from "./Components/Container";
|
||||
import { StyleFunctionProps, Styles } from "@chakra-ui/theme-tools";
|
||||
|
||||
const config: ThemeConfig = {
|
||||
initialColorMode: "system",
|
||||
useSystemColorMode: false,
|
||||
disableTransitionOnChange: true,
|
||||
};
|
||||
|
||||
const components = {
|
||||
Container: containerTheme,
|
||||
Box: (props: StyleFunctionProps) => ({
|
||||
backgroundColor: props.colorMode === "light" ? "white" : "gray.800",
|
||||
}),
|
||||
Button: {
|
||||
baseStyle: {
|
||||
fontWeight: "normal",
|
||||
},
|
||||
sizes: {
|
||||
lg: {
|
||||
fontSize: "md",
|
||||
paddingY: "7",
|
||||
},
|
||||
},
|
||||
variants: {
|
||||
solid: (props: StyleFunctionProps) => ({
|
||||
bg: props.colorMode === "light" ? "gray.100" : "gray.600",
|
||||
_hover: {
|
||||
bg: props.colorMode === "light" ? "gray.200" : "#3D4A60",
|
||||
},
|
||||
_active: {
|
||||
bg: props.colorMode === "light" ? "gray.300" : "#374254",
|
||||
},
|
||||
borderRadius: "lg",
|
||||
}),
|
||||
// gradient: (props: StyleFunctionProps) => ({
|
||||
// bg: `linear-gradient(${white}, ${bgColor}) padding-box,
|
||||
// linear-gradient(135deg, ${lgFrom}, ${lgTo}) border-box`,
|
||||
// }),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const breakpoints = {
|
||||
sm: "640px",
|
||||
md: "768px",
|
||||
lg: "1024px",
|
||||
xl: "1280px",
|
||||
"2xl": "1536px",
|
||||
};
|
||||
|
||||
const styles = {
|
||||
global: (props) => ({
|
||||
main: {
|
||||
fontFamily: "Inter",
|
||||
},
|
||||
header: {
|
||||
fontFamily: "Inter",
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
export const theme = extendTheme({ config, styles, components, breakpoints });
|
||||
Vendored
+1
-1
@@ -3,7 +3,7 @@ declare global {
|
||||
interface ProcessEnv {
|
||||
NODE_ENV: "development" | "production";
|
||||
NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY: string;
|
||||
CLOUDFLARE_CAPTCHA_SERCERT_KEY: string;
|
||||
CLOUDFLARE_CAPTCHA_SECRET_KEY: string;
|
||||
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA: boolean;
|
||||
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN: boolean;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user