Merge branch 'main' into 911_control_email_signin

This commit is contained in:
notmd
2023-02-03 13:31:18 +07:00
61 changed files with 2092 additions and 801 deletions
+9 -8
View File
@@ -33,14 +33,6 @@ jobs:
WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }}
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
WEB_NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }}
WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
${{ secrets.DEV_WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }}
WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY:
${{ secrets.DEV_WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY }}
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
${{ secrets.DEV_WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }}
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
${{ secrets.DEV_WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }}
S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
S3_REGION: ${{ secrets.S3_REGION }}
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
@@ -49,11 +41,20 @@ jobs:
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }}
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
MESSAGE_SIZE_LIMIT: ${{ vars.MESSAGE_SIZE_LIMIT }}
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
STATS_INTERVAL_DAY: ${{ vars.STATS_INTERVAL_DAY }}
STATS_INTERVAL_WEEK: ${{ vars.STATS_INTERVAL_WEEK }}
STATS_INTERVAL_MONTH: ${{ vars.STATS_INTERVAL_MONTH }}
STATS_INTERVAL_TOTAL: ${{ vars.STATS_INTERVAL_TOTAL }}
WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
${{ secrets.WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY }}
WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY:
${{ secrets.WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY }}
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA }}
WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN:
${{ vars.WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN }}
steps:
- name: Checkout
uses: actions/checkout@v2
+1 -1
View File
@@ -140,4 +140,4 @@ automatically deploy the built release to the dev machine.
### Contribute a Dataset
See
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/docs/docs/data/datasets.md)
[here](https://github.com/LAION-AI/Open-Assistant/blob/main/openassistant/datasets/README.md)
+5 -2
View File
@@ -122,6 +122,9 @@
TREE_MANAGER__MAX_CHILDREN_COUNT:
"{{ lookup('ansible.builtin.env', 'MAX_CHILDREN_COUNT') |
default('3', true) }}"
MESSAGE_SIZE_LIMIT:
"{{ lookup('ansible.builtin.env', 'MESSAGE_SIZE_LIMIT') |
default('2000', true) }}"
USER_STATS_INTERVAL_DAY:
"{{ lookup('ansible.builtin.env', 'STATS_INTERVAL_DAY') |
default('5', true) }}"
@@ -175,9 +178,9 @@
NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY:
"{{ lookup('ansible.builtin.env',
'WEB_NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY') }}"
CLOUDFLARE_CAPTCHA_SERCERT_KEY:
CLOUDFLARE_CAPTCHA_SECRET_KEY:
"{{ lookup('ansible.builtin.env',
'WEB_CLOUDFLARE_CAPTCHA_SERCERT_KEY') }}"
'WEB_CLOUDFLARE_CAPTCHA_SECRET_KEY') }}"
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA:
"{{ lookup('ansible.builtin.env',
'WEB_NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA') }}"
@@ -36,6 +36,7 @@ def upgrade() -> None:
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user_stats", "streak_days")
op.drop_column("user_stats", "streak_last_day_date")
op.drop_column("user", "streak_days")
op.drop_column("user", "streak_last_day_date")
op.drop_column("user", "last_activity_date")
# ### end Alembic commands ###
@@ -0,0 +1,34 @@
"""add tos_acceptance_date to user
Revision ID: 55361f323d12
Revises: 7b8f0011e0b0
Create Date: 2023-02-01 00:22:08.280251
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55361f323d12"
down_revision = "f60958968ff8"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
op.drop_column("user_stats", "streak_days")
op.drop_column("user_stats", "streak_last_day_date")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
)
op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
op.drop_column("user", "tos_acceptance_date")
# ### end Alembic commands ###
@@ -0,0 +1,27 @@
"""add won_prompt_lottery_date to mts
Revision ID: f60958968ff8
Revises: 7b8f0011e0b0
Create Date: 2023-02-01 10:10:38.301707
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "f60958968ff8"
down_revision = "7b8f0011e0b0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message_tree_state", sa.Column("won_prompt_lottery_date", sa.DateTime(timezone=True), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message_tree_state", "won_prompt_lottery_date")
# ### end Alembic commands ###
@@ -0,0 +1,30 @@
"""add skip bool & skip_reason to task
Revision ID: 9e7ec4a9e3f2
Revises: 7b8f0011e0b0
Create Date: 2023-02-01 21:46:49.971052
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "9e7ec4a9e3f2"
down_revision = "55361f323d12"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("task", sa.Column("skipped", sa.Boolean(), server_default=sa.text("false"), nullable=False))
op.add_column("task", sa.Column("skip_reason", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("task", "skip_reason")
op.drop_column("task", "skipped")
# ### end Alembic commands ###
@@ -0,0 +1,59 @@
"""add troll_stats
Revision ID: 4d7e0b0ebe84
Revises: 9e7ec4a9e3f2
Create Date: 2023-02-02 15:44:12.647260
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4d7e0b0ebe84"
down_revision = "9e7ec4a9e3f2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"troll_stats",
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("base_date", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
),
sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("troll_score", sa.Integer(), nullable=False),
sa.Column("rank", sa.Integer(), nullable=True),
sa.Column("red_flags", sa.Integer(), nullable=False),
sa.Column("upvotes", sa.Integer(), nullable=False),
sa.Column("downvotes", sa.Integer(), nullable=False),
sa.Column("spam_prompts", sa.Integer(), nullable=False),
sa.Column("quality", sa.Float(), nullable=True),
sa.Column("humor", sa.Float(), nullable=True),
sa.Column("toxicity", sa.Float(), nullable=True),
sa.Column("violence", sa.Float(), nullable=True),
sa.Column("helpfulness", sa.Float(), nullable=True),
sa.Column("spam", sa.Integer(), nullable=False),
sa.Column("lang_mismach", sa.Integer(), nullable=False),
sa.Column("not_appropriate", sa.Integer(), nullable=False),
sa.Column("pii", sa.Integer(), nullable=False),
sa.Column("hate_speech", sa.Integer(), nullable=False),
sa.Column("sexual_content", sa.Integer(), nullable=False),
sa.Column("political_content", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "time_frame"),
)
op.create_index("ix_troll_stats__timeframe__user_id", "troll_stats", ["time_frame", "user_id"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_troll_stats__timeframe__user_id", table_name="troll_stats")
op.drop_table("troll_stats")
# ### end Alembic commands ###
@@ -0,0 +1,39 @@
"""Add Account table
Revision ID: 8c8241d1f973
Revises: 4d7e0b0ebe84
Create Date: 2023-01-30 15:10:58.776315
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "8c8241d1f973"
down_revision = "4d7e0b0ebe84"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"account",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("provider", "account", ["provider_account_id"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("provider", table_name="account")
op.drop_table("account")
# ### end Alembic commands ###
+1
View File
@@ -147,6 +147,7 @@ if settings.DEBUG_USE_SEED_DATA:
ur = UserRepository(db=session, api_client=api_client)
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
pr = PromptRepository(
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
+2
View File
@@ -10,6 +10,7 @@ from oasst_backend.api.v1 import (
stats,
tasks,
text_labels,
trollboards,
users,
)
@@ -22,6 +23,7 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(trollboards.router, prefix="/trollboards", tags=["trollboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
@@ -59,6 +59,37 @@ def query_frontend_user(
return user.to_protocol_frontend_user()
@router.post("/", response_model=protocol.FrontEndUser)
def create_frontend_user(
*,
create_user: protocol.CreateFrontendUserRequest,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
user = ur.lookup_client_user(create_user, create_missing=True)
def changed(a, b) -> bool:
return a is not None and a != b
# only call update_user if something changed
if (
changed(create_user.enabled, user.enabled)
or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
or changed(create_user.notes, user.notes)
or (create_user.tos_acceptance and user.tos_acceptance_date is None)
):
user = ur.update_user(
user.id,
enabled=create_user.enabled,
show_on_leaderboard=create_user.show_on_leaderboard,
tos_acceptance=create_user.tos_acceptance,
notes=create_user.notes,
)
return user.to_protocol_frontend_user()
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
auth_method: str,
+73
View File
@@ -0,0 +1,73 @@
import aiohttp
from fastapi import APIRouter, Depends, HTTPException, Request
from oasst_backend import auth
from oasst_backend.api import deps
from oasst_backend.config import Settings
from oasst_backend.models import Account
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_401_UNAUTHORIZED
router = APIRouter()
@router.get("/discord")
def login_discord(request: Request):
redirect_uri = f"{get_callback_uri(request)}/discord"
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={Settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
raise HTTPException(status_code=302, headers={"location": auth_url})
@router.get("/callback/discord", response_model=protocol_schema.Token)
async def callback_discord(
auth_code: str,
request: Request,
db: Session = Depends(deps.get_db),
):
redirect_uri = f"{get_callback_uri(request)}/discord"
async with aiohttp.ClientSession(raise_for_status=True) as session:
# Exchange the auth code for a Discord access token
async with session.post(
"https://discord.com/api/oauth2/token",
data={
"client_id": Settings.AUTH_DISCORD_CLIENT_ID,
"client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET,
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": redirect_uri,
"scope": "identify",
},
) as token_response:
token_response_json = await token_response.json()
access_token = token_response_json["access_token"]
# Retrieve user's Discord information using access token
async with session.get(
"https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
) as user_response:
user_response_json = await user_response.json()
discord_id = user_response_json["id"]
account: Account = auth.get_account_from_discord_id(db, discord_id)
if not account:
# Discord account is not linked to an OA account
raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
# Discord account is valid and linked to an OA account -> create JWT
access_token = auth.create_access_token(account)
return protocol_schema.Token(access_token=access_token, token_type="bearer")
def get_callback_uri(request: Request):
"""
Gets the URI for the base callback endpoint with no provider name appended.
"""
# This seems ugly, not sure if there is a better way
current_url = str(request.url)
domain = current_url.split("/api/v1/")[0]
redirect_uri = f"{domain}/api/v1/callback"
return redirect_uri
+1 -1
View File
@@ -131,7 +131,7 @@ def tasks_acknowledge_failure(
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
pr.task_repository.acknowledge_task_failure(task_id)
pr.skip_task(task_id=task_id, reason=nack_request.reason)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@@ -0,0 +1,21 @@
from typing import Optional
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import TrollboardStats
from sqlmodel import Session
router = APIRouter()
@router.get("/{time_frame}", response_model=TrollboardStats)
def get_trollboard(
time_frame: UserStatsTimeFrame,
max_count: Optional[int] = Query(100, gt=0, le=10000),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
db: Session = Depends(deps.get_db),
) -> TrollboardStats:
usr = UserStatsRepository(db)
return usr.get_trollboard(time_frame, limit=max_count)
+2 -1
View File
@@ -191,6 +191,7 @@ def update_user(
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
tos_acceptance: Optional[bool] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
@@ -198,7 +199,7 @@ def update_user(
Update a user by global user ID. Only trusted clients can update users.
"""
ur = UserRepository(db, api_client)
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
+37
View File
@@ -0,0 +1,37 @@
from datetime import datetime, timedelta
from typing import Optional
from jose import jwt
from oasst_backend.config import Settings
from oasst_backend.models import Account
from sqlmodel import Session
def create_access_token(data: dict) -> str:
"""
Create an encoded JSON Web Token (JWT) using the given data.
"""
expires_delta = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM)
return encoded_jwt
def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
"""
Get the Open-Assistant Account associated with the given Discord ID.
"""
account: Account = (
db.query(Account)
.filter(
Account.provider == "discord",
Account.provider_account_id == discord_id,
)
.first()
)
return account
+10 -2
View File
@@ -13,7 +13,7 @@ class TreeManagerConfiguration(BaseModel):
No new initial prompt tasks are handed out to users if this
number is reached."""
max_tree_depth: int = 6
max_tree_depth: int = 3
"""Maximum depth of message tree."""
max_children_count: int = 3
@@ -22,7 +22,7 @@ class TreeManagerConfiguration(BaseModel):
num_prompter_replies: int = 1
"""Number of prompter replies to collect per assistant reply."""
goal_tree_size: int = 15
goal_tree_size: int = 12
"""Total number of messages to gather per tree."""
num_reviews_initial_prompt: int = 3
@@ -135,6 +135,11 @@ class Settings(BaseSettings):
AUTH_LENGTH: int = 32
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
AUTH_COOKIE_NAME: str = "next-auth.session-token"
AUTH_ALGORITHM: str = "HS256"
AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
AUTH_DISCORD_CLIENT_ID: str = ""
AUTH_DISCORD_CLIENT_SECRET: str = ""
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
@@ -158,6 +163,9 @@ class Settings(BaseSettings):
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
DEBUG_DATABASE_ECHO: bool = False
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
True # TODO: set False after ToS acceptance UI was added to web-frontend
)
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
+2
View File
@@ -8,6 +8,7 @@ from .message_toxicity import MessageToxicity
from .message_tree_state import MessageTreeState
from .task import Task
from .text_labels import TextLabels
from .troll_stats import TrollStats
from .user import User
from .user_stats import UserStats, UserStatsTimeFrame
@@ -26,4 +27,5 @@ __all__ = [
"Journal",
"JournalIntegration",
"MessageEmoji",
"TrollStats",
]
@@ -1,4 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
@@ -46,6 +48,9 @@ class State(str, Enum):
BACKLOG_RANKING = "backlog_ranking"
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
PROMPT_LOTTERY_WAITING = "prompt_lottery_waiting"
"""Initial prompt has passed spam check, waiting to be drawn to grow."""
VALID_STATES = (
State.INITIAL_PROMPT_REVIEW,
@@ -63,6 +68,7 @@ TERMINAL_STATES = (
State.SCORING_FAILED,
State.HALTED_BY_MODERATOR,
State.BACKLOG_RANKING,
State.PROMPT_LOTTERY_WAITING,
)
@@ -78,3 +84,4 @@ class MessageTreeState(SQLModel, table=True):
state: str = Field(nullable=False, max_length=128, index=True)
active: bool = Field(nullable=False, index=True)
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
+2
View File
@@ -31,6 +31,8 @@ class Task(SQLModel, table=True):
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
ack: Optional[bool] = None
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
skipped: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
skip_reason: str = Field(nullable=True, max_length=512)
frontend_message_id: Optional[str] = None
message_tree_id: Optional[UUID] = None
parent_message_id: Optional[UUID] = None
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class TrollStats(SQLModel, table=True):
__tablename__ = "troll_stats"
__table_args__ = (Index("ix_troll_stats__timeframe__user_id", "time_frame", "user_id", unique=True),)
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
user_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
)
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
troll_score: int = 0
modified_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
rank: int = Field(nullable=True)
red_flags: int = 0 # num reported messages of user
upvotes: int = 0 # num up-voted messages of user
downvotes: int = 0 # num down-voted messages of user
spam_prompts: int = 0
quality: float = Field(nullable=True)
humor: float = Field(nullable=True)
toxicity: float = Field(nullable=True)
violence: float = Field(nullable=True)
helpfulness: float = Field(nullable=True)
spam: int = 0
lang_mismach: int = 0
not_appropriate: int = 0
pii: int = 0
hate_speech: int = 0
sexual_content: int = 0
political_content: int = 0
def compute_troll_score(self) -> int:
return (
self.red_flags * 3
- self.upvotes
+ self.downvotes
+ self.spam_prompts
+ self.lang_mismach
+ self.not_appropriate
+ self.pii
+ self.hate_speech
+ self.sexual_content
+ self.political_content
)
+18
View File
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
)
# terms of service acceptance date
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
def to_protocol_frontend_user(self):
return protocol.FrontEndUser(
user_id=self.id,
@@ -55,4 +58,19 @@ class User(SQLModel, table=True):
streak_days=self.streak_days,
streak_last_day_date=self.streak_last_day_date,
last_activity_date=self.last_activity_date,
tos_acceptance_date=self.tos_acceptance_date,
)
class Account(SQLModel, table=True):
__tablename__ = "account"
__table_args__ = (Index("provider", "provider_account_id", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
user_id: UUID = Field(foreign_key="user.id")
provider: str = Field(nullable=False, max_length=128, default="email") # discord or email
provider_account_id: str = Field(nullable=False, max_length=128)
+58 -8
View File
@@ -35,7 +35,21 @@ from oasst_shared.utils import unaware_to_utc, utcnow
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
_task_type_and_reaction = (
(
(db_payload.PrompterReplyPayload, db_payload.AssistantReplyPayload),
protocol_schema.EmojiCode.skip_reply,
),
(
(db_payload.LabelInitialPromptPayload, db_payload.LabelConversationReplyPayload),
protocol_schema.EmojiCode.skip_labeling,
),
(
(db_payload.RankInitialPromptsPayload, db_payload.RankConversationRepliesPayload),
protocol_schema.EmojiCode.skip_ranking,
),
)
class PromptRepository:
@@ -77,7 +91,14 @@ class PromptRepository:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
if self.user.deleted or not self.user.enabled:
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
raise OasstError(
"User has not accepted terms of service.",
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
)
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
validate_frontend_message_id(frontend_message_id)
@@ -90,7 +111,7 @@ class PromptRepository:
raise OasstError(
f"Message with frontend_message_id {frontend_message_id} not found.",
OasstErrorCode.MESSAGE_NOT_FOUND,
HTTP_404_NOT_FOUND,
HTTPStatus.NOT_FOUND,
)
return message
@@ -139,7 +160,12 @@ class PromptRepository:
return message
def _validate_task(
self, task: Task, *, task_id: Optional[UUID] = None, frontend_message_id: Optional[str] = None
self,
task: Task,
*,
task_id: Optional[UUID] = None,
frontend_message_id: Optional[str] = None,
check_ack: bool = True,
) -> Task:
if task is None:
if task_id:
@@ -150,7 +176,7 @@ class PromptRepository:
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if not task.ack:
if check_ack and not task.ack:
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
if task.done:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
@@ -675,7 +701,7 @@ class PromptRepository:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
return message
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
@@ -874,7 +900,7 @@ class PromptRepository:
if api_client_id != self.api_client.id:
# Unprivileged api client asks for foreign messages
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
qry = self.db.query(Message)
if user_id:
@@ -995,7 +1021,31 @@ WHERE message.id = cc.id;
message_trees=result.get(None, 0),
)
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
@managed_tx_method()
def skip_task(self, task_id: UUID, reason: str):
self.ensure_user_is_enabled()
task = self.task_repository.fetch_task_by_id(task_id)
self._validate_task(task, check_ack=False)
if not task.collective:
task.skipped = True
task.skip_reason = reason
self.db.add(task)
def handle_cancel_emoji(task_payload: db_payload.TaskPayload) -> Message | None:
for types, emoji in _task_type_and_reaction:
for t in types:
if isinstance(task_payload, t):
return self.handle_message_emoji(task.parent_message_id, protocol_schema.EmojiOp.add, emoji)
return None
task_payload: db_payload.TaskPayload = task.payload.payload
handle_cancel_emoji(task_payload)
def handle_message_emoji(
self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema.EmojiCode
) -> Message:
self.ensure_user_is_enabled()
message = self.fetch_message(message_id)
-15
View File
@@ -167,21 +167,6 @@ class TaskRepository:
task.done = True
self.db.add(task)
@managed_tx_method(CommitMode.COMMIT)
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
@managed_tx_method(CommitMode.COMMIT)
def insert_task(
self,
+138 -22
View File
@@ -13,7 +13,16 @@ from fastapi.encoders import jsonable_encoder
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
from oasst_backend.models import (
Message,
MessageEmoji,
MessageReaction,
MessageTreeState,
Task,
TextLabels,
User,
message_tree_state,
)
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils import tree_export
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
@@ -21,7 +30,8 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, not_, or_, text, update
from oasst_shared.utils import utcnow
from sqlmodel import Session, and_, func, not_, or_, text, update
class TaskType(Enum):
@@ -153,7 +163,7 @@ class TreeManager:
def _determine_task_availability_internal(
self,
num_active_trees: int,
num_missing_prompts: int,
extendible_parents: list[ExtendibleParentRow],
prompts_need_review: list[Message],
replies_need_review: list[Message],
@@ -161,8 +171,7 @@ class TreeManager:
) -> dict[protocol_schema.TaskRequestType, int]:
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}
num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees)
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = max(1, num_missing_prompts)
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
@@ -194,6 +203,72 @@ class TreeManager:
return task_count_by_type
def _prompt_lottery(self, lang: str) -> int:
MAX_RETRIES = 5
retry = 0
while True:
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
num_missing_prompts = self.cfg.max_active_trees - num_active_trees
if num_missing_prompts <= 0:
return 0
# select among distinct users
authors_qry = (
self.db.query(Message.user_id)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
)
.distinct(Message.user_id)
)
author_ids = authors_qry.all()
if len(author_ids) == 0:
logger.info(
f"No prompts for prompt lottery available ({num_missing_prompts} trees missing for {lang=})."
)
return num_missing_prompts
# first select an authour
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
# select random prompt of author
qry = (
self.db.query(MessageTreeState, Message)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.user_id == prompt_author_id,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
)
.limit(100)
)
prompt_candidates = qry.all()
if len(prompt_candidates) == 0:
retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry
if retry < MAX_RETRIES:
continue
else:
logger.warning("Max retries in prompt lottery reached.")
return num_missing_prompts
winner_prompt = random.choice(prompt_candidates)
message: Message = winner_prompt.Message
logger.info(f"Prompt lottery winner: {message.id=}")
mts: MessageTreeState = winner_prompt.MessageTreeState
self._enter_state(mts, message_tree_state.State.GROWING)
self.db.flush()
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()
@@ -201,14 +276,14 @@ class TreeManager:
lang = "en"
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
num_missing_prompts = self._prompt_lottery(lang=lang)
extendible_parents, _ = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
return self._determine_task_availability_internal(
num_active_trees=num_active_trees,
num_missing_prompts=num_missing_prompts,
extendible_parents=extendible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
@@ -238,7 +313,8 @@ class TreeManager:
lang = "en"
logger.warning("Task request without lang tag received, assuming 'en'.")
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
num_missing_prompts = self._prompt_lottery(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents, active_tree_sizes = self.query_extendible_parents(lang=lang)
@@ -256,7 +332,7 @@ class TreeManager:
num_ranking_tasks=len(incomplete_rankings),
num_replies_need_review=len(replies_need_review),
num_prompts_need_review=len(prompts_need_review),
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
num_missing_prompts=num_missing_prompts,
num_missing_replies=num_missing_replies,
)
@@ -268,7 +344,7 @@ class TreeManager:
)
else:
task_count_by_type = self._determine_task_availability_internal(
num_active_trees=num_active_trees,
num_missing_prompts=num_missing_prompts,
extendible_parents=extendible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
@@ -611,7 +687,7 @@ class TreeManager:
)
else:
self.enter_low_grade_state(msg.message_tree_id)
self.check_condition_for_growing_state(msg.message_tree_id)
self.check_condition_for_prompt_lottery(msg.message_tree_id)
elif msg.review_count >= self.cfg.num_reviews_reply:
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
msg.review_result = True
@@ -649,6 +725,8 @@ class TreeManager:
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
self.activate_backlog_tree(lang=root_msg.lang)
else:
if mts.state == message_tree_state.State.GROWING and mts.won_prompt_lottery_date is None:
mts.won_prompt_lottery_date = utcnow()
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
def enter_low_grade_state(self, message_tree_id: UUID) -> None:
@@ -656,8 +734,8 @@ class TreeManager:
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
def check_condition_for_prompt_lottery(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_prompt_lottery({message_tree_id=})")
mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.INITIAL_PROMPT_REVIEW:
@@ -670,7 +748,7 @@ class TreeManager:
logger.debug(f"False {initial_prompt.review_result=}")
return False
self._enter_state(mts, message_tree_state.State.GROWING)
self._enter_state(mts, message_tree_state.State.PROMPT_LOTTERY_WAITING)
return True
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
@@ -820,6 +898,14 @@ class TreeManager:
self.db.query(Message)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.outerjoin(
MessageEmoji,
and_(
Message.id == MessageEmoji.message_id,
MessageEmoji.user_id == self.pr.user_id,
MessageEmoji.emoji == protocol_schema.EmojiCode.skip_labeling,
),
)
.filter(
MessageTreeState.active,
MessageTreeState.state == state,
@@ -827,6 +913,7 @@ class TreeManager:
not_(Message.deleted),
Message.review_count < required_reviews,
Message.lang == lang,
MessageEmoji.message_id.is_(None),
)
)
@@ -879,12 +966,18 @@ SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) chi
mts.message_tree_id
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
LEFT JOIN message_emoji me on
(m.parent_id = me.message_id
AND :skip_user_id IS NOT NULL
AND me.user_id = :skip_user_id
AND me.emoji = :skip_ranking)
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND m.lang = :lang -- matches lang
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
AND me.message_id IS NULL -- no skip ranking emoji for user
GROUP BY m.parent_id, m.role, mts.message_tree_id
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
@@ -910,6 +1003,8 @@ HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
"ranking_state": message_tree_state.State.RANKING,
"lang": lang,
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_ranking": protocol_schema.EmojiCode.skip_ranking,
},
)
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
@@ -919,6 +1014,11 @@ HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message_emoji me ON
(m.id = me.message_id
AND :skip_user_id IS NOT NULL
AND me.user_id = :skip_user_id
AND me.emoji = :skip_reply)
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
@@ -926,6 +1026,7 @@ WHERE mts.active -- only consider active trees
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
AND m.lang = :lang -- parent matches lang
AND me.message_id IS NULL -- no skip reply emoji for user
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
@@ -946,6 +1047,8 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"num_prompter_replies": self.cfg.num_prompter_replies,
"lang": lang,
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_reply": protocol_schema.EmojiCode.skip_reply,
},
)
@@ -984,6 +1087,8 @@ HAVING COUNT(m.id) < mts.goal_tree_size
"num_prompter_replies": self.cfg.num_prompter_replies,
"lang": lang,
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_reply": protocol_schema.EmojiCode.skip_reply,
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
@@ -1097,7 +1202,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
f"Checking state of {len(prompt_review_trees)} active message trees in 'initial_prompt_review' state."
)
for t in prompt_review_trees:
self.check_condition_for_growing_state(t.message_tree_id)
self.check_condition_for_prompt_lottery(t.message_tree_id)
growing_trees: list[MessageTreeState] = (
self.db.query(MessageTreeState)
@@ -1288,11 +1393,17 @@ DELETE FROM message WHERE message_tree_id = :message_tree_id;
logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.")
def _reactivate_tree(self, mts: MessageTreeState):
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
if mts.state == message_tree_state.State.PROMPT_LOTTERY_WAITING:
return
tree_id = mts.message_tree_id
if self.check_condition_for_growing_state(tree_id):
if mts.won_prompt_lottery_date is not None:
self._enter_state(mts, message_tree_state.State.GROWING)
if self.check_condition_for_ranking_state(tree_id):
self.check_condition_for_scoring_state(tree_id)
else:
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
self.check_condition_for_prompt_lottery(tree_id)
@managed_tx_method(CommitMode.FLUSH)
def purge_user_messages(
@@ -1450,7 +1561,8 @@ if __name__ == "__main__":
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
# api_client = create_api_client(session=db, description="test", frontend_type="bot")
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
# dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
dummy_user = protocol_schema.User(id="1234", display_name="bulb", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
cfg = TreeManagerConfiguration()
@@ -1465,14 +1577,18 @@ if __name__ == "__main__":
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
# print("query_replies_need_review", tm.query_replies_need_review())
# print("query_incomplete_reply_reviews", tm.query_replies_need_review())
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
xs = tm.query_prompts_need_review(lang="en")
print("xs", len(xs))
for x in xs:
print(x.id, x.emojis)
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review(lang="en"))
# print("query_extendible_trees", tm.query_extendible_trees())
# print("query_extendible_parents", tm.query_extendible_parents())
# print("next_task:", tm.next_task())
print(
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
)
# print(
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
# )
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
+12 -2
View File
@@ -73,7 +73,8 @@ class UserRepository:
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
) -> None:
tos_acceptance: Optional[bool] = None,
) -> User:
"""
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
@@ -94,8 +95,11 @@ class UserRepository:
user.notes = notes
if show_on_leaderboard is not None:
user.show_on_leaderboard = show_on_leaderboard
if tos_acceptance:
user.tos_acceptance_date = utcnow()
self.db.add(user)
return user
@managed_tx_method(CommitMode.COMMIT)
def mark_user_deleted(self, id: UUID) -> None:
@@ -143,8 +147,10 @@ class UserRepository:
display_name=display_name,
api_client_id=self.api_client.id,
auth_method=auth_method,
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
)
if auth_method == "system":
user.show_on_leaderboard = False # don't show system users, e.g. import user
user.tos_acceptance_date = utcnow()
self.db.add(user)
elif display_name and display_name != user.display_name:
# we found the user but the display name changed
@@ -156,6 +162,10 @@ class UserRepository:
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
if not client_user:
return None
if not (client_user.auth_method and client_user.id):
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
for i in range(num_retries):
try:
+236 -8
View File
@@ -5,19 +5,31 @@ from uuid import UUID
import sqlalchemy as sa
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models import (
Message,
MessageReaction,
MessageTreeState,
Task,
TextLabels,
TrollStats,
User,
UserStats,
UserStatsTimeFrame,
)
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
LabelPrompterReplyPayload,
RankingReactionPayload,
)
from oasst_shared.schemas.protocol import LeaderboardStats, UserScore
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore
from oasst_shared.utils import log_timing, utcnow
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.functions import coalesce
from sqlmodel import Session, delete, func, text
def _create_user_score(r, highlighted_user_id: UUID | None):
def _create_user_score(r, highlighted_user_id: UUID | None) -> UserScore:
if r["UserStats"]:
d = r["UserStats"].dict()
else:
@@ -37,6 +49,24 @@ def _create_user_score(r, highlighted_user_id: UUID | None):
return UserScore(**d)
def _create_troll_score(r, highlighted_user_id: UUID | None) -> TrollScore:
if r["TrollStats"]:
d = r["TrollStats"].dict()
else:
d = {"modified_date": utcnow()}
for k in [
"user_id",
"username",
"auth_method",
"display_name",
"last_activity_date",
]:
d[k] = r[k]
if highlighted_user_id:
d["highlighted"] = r["user_id"] == highlighted_user_id
return TrollScore(**d)
class UserStatsRepository:
def __init__(self, session: Session):
self.session = session
@@ -133,6 +163,38 @@ class UserStatsRepository:
stats_by_timeframe = {tf.value: _create_user_score(r, user_id) for tf in UserStatsTimeFrame}
return stats_by_timeframe
def get_trollboard(
self,
time_frame: UserStatsTimeFrame,
limit: int = 100,
highlighted_user_id: Optional[UUID] = None,
) -> TrollboardStats:
"""
Get trollboard stats for the specified time frame
"""
qry = (
self.session.query(
User.id.label("user_id"),
User.username,
User.auth_method,
User.display_name,
User.last_activity_date,
TrollStats,
)
.join(TrollStats, User.id == TrollStats.user_id)
.filter(TrollStats.time_frame == time_frame.value)
.order_by(TrollStats.rank)
.limit(limit)
)
trollboard = [_create_troll_score(r, highlighted_user_id) for r in self.session.exec(qry)]
if len(trollboard) > 0:
last_update = max(x.modified_date for x in trollboard)
else:
last_update = utcnow()
return TrollboardStats(time_frame=time_frame.value, trollboard=trollboard, last_updated=last_update)
def query_total_prompts_per_user(
self, reference_time: Optional[datetime] = None, only_reviewed: Optional[bool] = True
):
@@ -292,10 +354,145 @@ class UserStatsRepository:
self.session.add_all(stats_by_user.values())
self.session.flush()
self.update_ranks(time_frame=time_frame)
self.update_leader_ranks(time_frame=time_frame)
def query_message_emoji_counts_per_user(self, reference_time: Optional[datetime] = None):
qry = self.session.query(
Message.user_id,
func.sum(coalesce(Message.emojis[EmojiCode.thumbs_up].cast(sa.Integer), 0)).label("up"),
func.sum(coalesce(Message.emojis[EmojiCode.thumbs_down].cast(sa.Integer), 0)).label("down"),
func.sum(coalesce(Message.emojis[EmojiCode.red_flag].cast(sa.Integer), 0)).label("flag"),
).filter(Message.deleted == sa.false(), Message.emojis.is_not(None))
if reference_time:
qry = qry.filter(Message.created_date >= reference_time)
qry = qry.group_by(Message.user_id)
return qry
def query_spam_prompts_per_user(self, reference_time: Optional[datetime] = None):
qry = (
self.session.query(Message.user_id, func.count().label("spam_prompts"))
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.filter(MessageTreeState.state == TreeState.ABORTED_LOW_GRADE)
)
if reference_time:
qry = qry.filter(Message.created_date >= reference_time)
qry = qry.group_by(Message.user_id)
return qry
def query_labels_per_user(self, reference_time: Optional[datetime] = None):
qry = (
self.session.query(
Message.user_id,
func.sum(coalesce(TextLabels.labels[TextLabel.spam].cast(sa.Integer), 0)).label("spam"),
func.sum(coalesce(TextLabels.labels[TextLabel.lang_mismatch].cast(sa.Integer), 0)).label(
"lang_mismach"
),
func.sum(coalesce(TextLabels.labels[TextLabel.not_appropriate].cast(sa.Integer), 0)).label(
"not_appropriate"
),
func.sum(coalesce(TextLabels.labels[TextLabel.pii].cast(sa.Integer), 0)).label("pii"),
func.sum(coalesce(TextLabels.labels[TextLabel.hate_speech].cast(sa.Integer), 0)).label("hate_speech"),
func.sum(coalesce(TextLabels.labels[TextLabel.sexual_content].cast(sa.Integer), 0)).label(
"sexual_content"
),
func.sum(coalesce(TextLabels.labels[TextLabel.political_content].cast(sa.Integer), 0)).label(
"political_content"
),
func.avg(TextLabels.labels[TextLabel.quality].cast(sa.Float)).label("quality"),
func.avg(TextLabels.labels[TextLabel.humor].cast(sa.Float)).label("humor"),
func.avg(TextLabels.labels[TextLabel.toxicity].cast(sa.Float)).label("toxicity"),
func.avg(TextLabels.labels[TextLabel.violence].cast(sa.Float)).label("violence"),
func.avg(TextLabels.labels[TextLabel.helpfulness].cast(sa.Float)).label("helpfulness"),
)
.select_from(TextLabels)
.join(Message, TextLabels.message_id == Message.id)
.filter(Message.deleted == sa.false(), Message.emojis.is_not(None))
)
if reference_time:
qry = qry.filter(Message.created_date >= reference_time)
qry = qry.group_by(Message.user_id)
return qry
def _update_troll_stats_internal(self, time_frame: UserStatsTimeFrame, base_date: Optional[datetime] = None):
# gather user data
time_frame_key = time_frame.value
stats_by_user: dict[UUID, TrollStats] = dict()
now = utcnow()
def get_stats(id: UUID) -> TrollStats:
us = stats_by_user.get(id)
if not us:
us = TrollStats(user_id=id, time_frame=time_frame_key, modified_date=now, base_date=base_date)
stats_by_user[id] = us
return us
# emoji counts of user's messages
qry = self.query_message_emoji_counts_per_user(reference_time=base_date)
for r in qry:
uid = r["user_id"]
s = get_stats(uid)
s.upvotes = r["up"]
s.downvotes = r["down"]
s.red_flags = r["flag"]
# num spam prompts
qry = self.query_spam_prompts_per_user(reference_time=base_date)
for r in qry:
uid, count = r
s = get_stats(uid).spam_prompts = count
label_field_names = (
"quality",
"humor",
"toxicity",
"violence",
"helpfulness",
"spam",
"lang_mismach",
"not_appropriate",
"pii",
"hate_speech",
"sexual_content",
"political_content",
)
# label counts / mean values
qry = self.query_labels_per_user(reference_time=base_date)
for r in qry:
uid = r["user_id"]
s = get_stats(uid)
for fn in label_field_names:
setattr(s, fn, r[fn])
# delete all existing stast for time frame
d = delete(TrollStats).where(TrollStats.time_frame == time_frame_key)
self.session.execute(d)
if None in stats_by_user:
logger.warning("Some messages in DB have NULL values in user_id column.")
del stats_by_user[None]
# compute magic leader score
for v in stats_by_user.values():
v.troll_score = v.compute_troll_score()
# insert user objects
self.session.add_all(stats_by_user.values())
self.session.flush()
self.update_troll_ranks(time_frame=time_frame)
@log_timing(log_kwargs=True)
def update_ranks(self, time_frame: UserStatsTimeFrame = None):
def update_leader_ranks(self, time_frame: UserStatsTimeFrame = None):
"""
Update user_stats ranks. The persisted rank values allow to
quickly the rank of a single user and to query nearby users.
@@ -329,10 +526,41 @@ WHERE
r = self.session.execute(
text(sql_update_rank), {"time_frame": time_frame.value if time_frame is not None else None}
)
logger.debug(f"pre_compute_ranks updated({time_frame=}) {r.rowcount} rows.")
logger.debug(f"pre_compute_ranks leader updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(self, time_frame: UserStatsTimeFrame, reference_time: Optional[datetime] = None):
self._update_stats_internal(time_frame, reference_time)
@log_timing(log_kwargs=True)
def update_troll_ranks(self, time_frame: UserStatsTimeFrame = None):
sql_update_troll_rank = """
-- update rank
UPDATE troll_stats ts
SET "rank" = r."rank"
FROM
(SELECT
ROW_NUMBER () OVER(
PARTITION BY time_frame
ORDER BY troll_score DESC, user_id
) AS "rank", user_id, time_frame
FROM troll_stats ts2
WHERE (:time_frame IS NULL OR time_frame = :time_frame)) AS r
WHERE
ts.user_id = r.user_id
AND ts.time_frame = r.time_frame;"""
r = self.session.execute(
text(sql_update_troll_rank), {"time_frame": time_frame.value if time_frame is not None else None}
)
logger.debug(f"pre_compute_ranks troll updated({time_frame=}) {r.rowcount} rows.")
def update_stats_time_frame(
self,
time_frame: UserStatsTimeFrame,
reference_time: Optional[datetime] = None,
leader_stats: bool = True,
troll_stats: bool = True,
):
if leader_stats:
self._update_stats_internal(time_frame, reference_time)
if troll_stats:
self._update_troll_stats_internal(time_frame, reference_time)
self.session.commit()
@log_timing(log_kwargs=True, level="INFO")
+1
View File
@@ -1,3 +1,4 @@
aiohttp==3.8.3
alembic==1.8.1
cryptography==39.0.0
fastapi==0.88.0
+1 -1
View File
@@ -4,4 +4,4 @@ OWNER_IDS=[<your user id>, <other user ids>]
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
OASST_API_URL="http://localhost:8080" # No trailing '/'
OASST_API_KEY=""
OASST_API_KEY="1234"
+388 -391
View File
@@ -9,27 +9,25 @@ import lightbulb.decorators
import miru
from aiosqlite import Connection
from bot.messages import (
assistant_reply_message,
assistant_reply_messages,
confirm_label_response_message,
confirm_ranking_response_message,
confirm_text_response_message,
initial_prompt_message,
invalid_user_input_embed,
label_assistant_reply_message,
label_initial_prompt_message,
label_prompter_reply_message,
initial_prompt_messages,
label_assistant_reply_messages,
label_prompter_reply_messages,
plain_embed,
prompter_reply_message,
prompter_reply_messages,
rank_assistant_reply_message,
rank_initial_prompts_message,
rank_prompter_reply_message,
rank_conversation_reply_messages,
rank_initial_prompts_messages,
rank_prompter_reply_messages,
task_complete_embed,
)
from bot.settings import Settings
from loguru import logger
from oasst_shared.api_client import OasstApiClient, TaskType
from oasst_shared.api_client import OasstApiClient
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("WorkPlugin")
@@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
settings = Settings()
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
class _TaskHandler(t.Generic[_Task_contra]):
"""Handle user interaction for a task."""
def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
"""Create a new `TaskHandler`.
Args:
ctx (lightbulb.Context): The context of the command that started the task.
task (_Task_contra): The task to handle.
"""
self.ctx = ctx
self.task = task
self.task_messages = self.get_task_messages(task)
self.sent_messages: list[hikari.Message] = []
@staticmethod
def get_task_messages(task: _Task_contra) -> list[str]:
"""Get the messages to send to the user for the task."""
raise NotImplementedError
async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
"""Send the task and wait for the user to accept/skip/cancel it."""
# Send all but the last message because we need to attach buttons to the last one
logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
for task_msg in self.task_messages[:-1]:
if len(task_msg) > 2000:
logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}")
task_msg = task_msg[:1999]
self.sent_messages.append(await self.ctx.author.send(task_msg))
# Send the last message with buttons
task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
logger.debug(f"TH Message length {len(self.task_messages[-1])}")
last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
await task_accept_view.start(last_msg)
await task_accept_view.wait()
return task_accept_view.choice
async def handle(self) -> None:
"""Handle the user's response to the task.
This method should be called after `send` has been called."""
# Ack task to the backend
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")
# Loop until the user's input is accepted
while True:
try:
# Wait for user to send a message
event = await self.ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
predicate=lambda e: (
e.author_id == self.ctx.author.id
and e.message.content is not None
and not e.message.content.startswith(settings.prefix)
),
timeout=MAX_TASK_TIME,
)
# Validate the message
if event.content is None or not self.check_user_input(event.content):
await self.ctx.author.send("Invalid input")
continue
# Confirm user input
if not (await self.confirm_user_input(event.content)):
continue
# Message is valid and confirmed by user
break
except asyncio.TimeoutError:
return
next_task = await self.notify(event.content, event)
if not isinstance(next_task, protocol_schema.TaskDone):
raise TypeError(f"Unknown task type: {next_task!r}")
return
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
"""Notify the backend that the user completed the task."""
raise NotImplementedError
async def confirm_user_input(self, content: str) -> bool:
"""Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
raise NotImplementedError
def check_user_input(self, content: str) -> bool:
"""Check the user's response to the task. Returns True if the response is valid."""
raise NotImplementedError
async def cancel(self, reason: str = "not specified") -> None:
"""Cancel the task."""
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
await oasst_api.nack_task(self.task.id, reason)
_Ranking_contra = t.TypeVar(
"_Ranking_contra",
bound=protocol_schema.RankAssistantRepliesTask
| protocol_schema.RankInitialPromptsTask
| protocol_schema.RankPrompterRepliesTask
| protocol_schema.RankConversationRepliesTask,
contravariant=True,
)
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
"""This should not be used directly. Use its subclasses instead."""
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
task = await oasst_api.post_interaction(
protocol_schema.MessageRanking(
user=protocol_schema.User(
id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
),
ranking=[int(r) - 1 for r in content.split(",")],
message_id=f"{self.sent_messages[0].id}",
)
)
db: Connection = self.ctx.bot.d.db
async with db.cursor() as cursor:
row = await (
await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
).fetchone()
log_channel = row[0] if row else None
log_messages: list[hikari.Message] = []
if log_channel is not None:
for message in self.task_messages[:-1]:
msg = await self.ctx.bot.rest.create_message(log_channel, message)
log_messages.append(msg)
await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
return task
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
@staticmethod
def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
return rank_assistant_reply_message(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.reply_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
super().__init__(ctx, task)
@staticmethod
def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
return rank_initial_prompts_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.prompt_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
@staticmethod
def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
return rank_prompter_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.reply_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
@staticmethod
def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
return rank_conversation_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.reply_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
@staticmethod
def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
return initial_prompt_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content) > 0
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
return prompter_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content) > 0
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
return assistant_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content) > 0
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
def check_user_input(self, content: str) -> bool:
user_labels = content.split(",")
return (
all([l in self.task.valid_labels for l in user_labels])
and self.task.mandatory_labels is not None
and all([m in user_labels for m in self.task.mandatory_labels])
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
return label_assistant_reply_messages(task)
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
return label_prompter_reply_messages(task)
summarize_story = "summarize_story"
rate_summary = "rate_summary"
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
required=False,
default=str(TaskRequestType.random),
type=str,
)
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def work(ctx: lightbulb.Context):
"""Create and handle a task."""
# Only send this message if started from a server
if ctx.guild_id is not None:
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
currently_working: dict[
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
] = ctx.bot.d.currently_working
async def work2(ctx: lightbulb.Context) -> None:
"""Complete a task."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
# Check if the user is already working on a task
if ctx.author.id in currently_working:
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
@@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context):
case False | None:
return
case True:
old_msg, task_id = currently_working[ctx.author.id]
if old_msg is not None:
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
map(lambda c: c, old_msg.components)
await old_msg.delete()
if task_id is not None:
await oasst_api.nack_task(task_id, reason="user cancelled")
task_id = currently_working[ctx.author.id]
await oasst_api.nack_task(task_id, reason="user cancelled")
await msg.delete()
if ctx.guild_id:
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
currently_working[ctx.author.id] = (None, None)
# Create a TaskRequestType from the stringified enum value
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
logger.debug(f"Starting task_type: {task_type!r}")
# Keep sending tasks until the user doesn't want more
try:
await _handle_task(ctx, task_type)
while True:
task = await oasst_api.fetch_random_task(
user=protocol_schema.User(
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
),
)
# Ranking tasks
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
task_handler = RankAssistantRepliesHandler(ctx, task)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
task_handler = RankInitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
task_handler = RankPrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
task_handler = RankConversationReplyHandler(ctx, task)
# Text input tasks
elif isinstance(task, protocol_schema.InitialPromptTask):
task_handler = InitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.PrompterReplyTask):
task_handler = PrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.AssistantReplyTask):
task_handler = AssistantReplyHandler(ctx, task)
# Label tasks
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
task_handler = LabelAssistantReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
task_handler = LabelPrompterReplyHandler(ctx, task)
else:
raise ValueError(f"Unknown task type: {type(task)}")
resp = await task_handler.send()
match resp:
case "accept":
currently_working[ctx.author.id] = task.id
await task_handler.handle()
case "next":
await task_handler.cancel("user skipped task")
case "cancel":
await task_handler.cancel("user canceled work")
break
case None:
await task_handler.cancel("select timed out")
break
finally:
del currently_working[ctx.author.id]
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
"""Handle creating and collecting user input for a task.
Continually present tasks to the user until they select one, cancel, or time out.
If they select one, present the task steps until a `task_done` task is received.
Finally, ask the user if they want to perform another task (of the same type).
"""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
# Continue to complete tasks until the user doesn't want to do another
done = False
while not done:
# Loop until the user accepts a task
task, msg_id = await _select_task(ctx, task_type)
if task is None:
# User cancelled
return
# Task action loop
completed = False
while not completed:
await ctx.author.send(embed=plain_embed("Please type your response below:"))
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id
and not (e.message.content or "").startswith(settings.prefix),
)
except asyncio.TimeoutError:
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
await oasst_api.nack_task(task.id, reason="timed out")
logger.info(f"Task {task.id} timed out")
return
# Invalid response
valid, err_msg = _validate_user_input(event.content, task)
if not valid or event.content is None:
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
continue
logger.debug(f"Successful user input received: {event.content}")
# Confirm user input
if isinstance(task, protocol_schema.RankConversationRepliesTask):
content = confirm_ranking_response_message(event.content, task.replies)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
content = confirm_ranking_response_message(event.content, task.prompts)
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
content = confirm_label_response_message(event.content)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
content = confirm_text_response_message(event.content)
else:
logger.critical(f"Unknown task type: {task.type}")
raise ValueError(f"Unknown task type: {task.type}")
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
msg = await ctx.author.send(content, components=confirm_resp_view)
await confirm_resp_view.start(msg)
await confirm_resp_view.wait()
match confirm_resp_view.choice:
case False | None:
continue
case True:
await msg.delete() # buttons are already gone
# Send the response to the backend
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
reply = protocol_schema.MessageRanking(
message_id=str(msg_id),
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
labels = event.content.replace(" ", "").split(",")
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
reply = protocol_schema.TextLabels(
message_id=task.message_id,
labels=labels_dict,
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
else:
logger.critical(f"Unexpected task type received: {task.type}")
raise ValueError(f"Unexpected task type received: {task.type}")
logger.debug(f"Sending reply to backend: {reply!r}")
# Get next task
new_task = await oasst_api.post_interaction(reply)
logger.info(f"New task {new_task}")
if new_task.type == TaskType.done:
await ctx.author.send(embed=plain_embed("Task completed"))
completed = True
continue
else:
logger.critical(f"Unexpected task type received: {new_task.type}")
# Send a message in all the log channels that the task is complete
conn: Connection = ctx.bot.d.db
async with conn.cursor() as cursor:
await cursor.execute("SELECT log_channel_id FROM guild_settings")
log_channel_ids = await cursor.fetchall()
channels = [
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
for id in log_channel_ids
]
done_embed = task_complete_embed(task, ctx.author.mention)
# This will definitely get the bot rate limited, but that's a future problem
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
# ask the user if they want to do another task
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
await another_task_view.start(msg)
await another_task_view.wait()
match another_task_view.choice:
case False | None:
done = True
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
case True:
pass
async def _select_task(
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
"""Present tasks to the user until they accept one, cancel, or time out."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
logger.debug(f"Starting task selection for {task_type}")
# Loop until the user accepts a task, cancels, or times out
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
while True:
logger.debug(f"Requesting task of type {task_type}")
task = await oasst_api.fetch_task(task_type, user)
resp, msg = await _send_task(ctx, task, msg)
msg_id = str(msg.id)
logger.debug(f"User choice: {resp}")
match resp:
case "accept":
logger.info(f"Task {task.id} accepted, sending ACK")
await oasst_api.ack_task(task.id, msg_id)
return task, msg_id
case "next":
logger.info(f"Task {task.id} rejected, sending NACK")
await oasst_api.nack_task(task.id, "rejected")
continue
case "cancel":
logger.info(f"Task {task.id} canceled, sending NACK")
await oasst_api.nack_task(task.id, "canceled")
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
return None, msg_id
case None:
logger.info(f"Task {task.id} timed out, sending NACK")
await oasst_api.nack_task(task.id, "timed out")
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
return None, msg_id
async def _send_task(
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
"""Send a task to the user.
Returns the user's choice and the message ID of the task message.
"""
# The clean way to do this would be to attach a `to_embed` method to the task classes
# but the tasks aren't discord specific so that doesn't really make sense.
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.debug("sending initial prompt task")
content = initial_prompt_message(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.debug("sending rank initial prompt task")
content = rank_initial_prompts_message(task)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
content = rank_prompter_reply_message(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
content = rank_assistant_reply_message(task)
elif task.type == TaskRequestType.label_initial_prompt:
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
logger.debug("sending label initial prompt task")
content = label_initial_prompt_message(task)
elif task.type == TaskRequestType.label_prompter_reply:
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
logger.debug("sending label prompter reply task")
content = label_prompter_reply_message(task)
elif task.type == TaskRequestType.label_assistant_reply:
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
logger.debug("sending label assistant reply task")
content = label_assistant_reply_message(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
content = prompter_reply_message(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.debug("sending assistant reply task")
content = assistant_reply_message(task)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"unknown task type {task.type}")
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
if not msg:
msg = await ctx.author.send(
content,
embed=embed,
components=view,
)
else:
await msg.edit(
content,
embed=embed,
components=view,
)
assert msg is not None
# Set the choice id as the current msg id
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
await view.start(msg)
await view.wait()
return view.choice, msg
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
"""Returns whether the user's input is valid for the task type and an error message."""
if content is None:
return False, "No input provided"
# User message input
if (
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.prompter_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0, "Message must be at least one character long."
# Ranking tasks
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
"Message must contain numbers for all replies.",
)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
"Message must contain numbers for all prompts.",
)
# Labels tasks
elif task.type in (
TaskRequestType.label_initial_prompt,
TaskRequestType.label_prompter_reply,
TaskRequestType.label_assistant_reply,
):
assert isinstance(
task,
protocol_schema.LabelInitialPromptTask
| protocol_schema.LabelPrompterReplyTask
| protocol_schema.LabelAssistantReplyTask,
)
labels = content.replace(" ", "").split(",")
valid_labels = set(task.valid_labels)
return (
set(labels).issubset(valid_labels),
"Message must only contain labels from predefined set of labels.",
)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"Unknown task type {task.type}")
raise ValueError(f"Unknown task type {task.type}")
class TaskAcceptView(miru.View):
"""View with three buttons: accept, next, and cancel.
+129 -69
View File
@@ -1,4 +1,11 @@
"""All user-facing messages and embeds."""
"""All user-facing messages and embeds.
When sending a conversation
- The function will return a list of strings
- use asyncio.gather to send all messages
-
"""
from datetime import datetime
@@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_"
def _label_prompt(text: str) -> str:
return f":question: _{text}"
def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
return f""":question: _{text}_
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
Valid labels: {", ".join(valid_labels)}
"""
def _response_prompt(text: str) -> str:
@@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str:
"""
def _make_ordered_list(items: list[str]) -> list[str]:
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)]
def _ordered_list(items: list[str]) -> str:
def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
return "\n\n".join(_make_ordered_list(items))
def _conversation(conv: protocol_schema.Conversation) -> str:
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
def _hint(hint: str | None) -> str:
return f"{NL}Hint: {hint}" if hint else ""
def _conversation(conv: protocol_schema.Conversation) -> list[str]:
# return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
messages = map(
lambda m: f"""\
:robot: __Assistant__:
{m.text}
"""
if m.is_assistant
else f"""\
:person_red_hair: __User__:
{m.text}
""",
conv.messages,
)
return list(messages)
def _li(text: str) -> str:
@@ -82,59 +101,80 @@ def _li(text: str) -> str:
###
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
return f"""\
return [
f"""\
{_h1("INITIAL PROMPT")}
:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:
{_writing_prompt("Please provide an initial prompt to the assistant.")}
{_hint(task.hint)}
:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""}
"""
]
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
return f"""\
return [
f"""\
{_h1("RANK INITIAL PROMPTS")}
:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:
{_ordered_list(task.prompts)}
{_ordered_list(task.prompt_messages)}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_
"""
]
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
return f"""\
return [
"""\
{_h1("RANK PROMPTER REPLIES")}
:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":person_red_hair: __User__:
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
]
{_conversation(task.conversation)}
{_user(None)}
{_ordered_list(task.replies)}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
return f"""\
return [
"""\
{_h1("RANK ASSISTANT REPLIES")}
:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":robot: __Assistant__:,
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
]
{_conversation(task.conversation)}
{_assistant(None)}
{_ordered_list(task.replies)}
def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_conversation_replies` task."""
return [
"""\
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":person_red_hair: __User__:
{_ordered_list(task.reply_messages)}
""",
]
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
@@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -
{task.prompt}
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
"""
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
return f"""\
return [
f"""\
{_h1("LABEL PROMPTER REPLY")}
{_conversation(task.conversation)}
{_user(None)}
""",
*_conversation(task.conversation),
f"""{_user(None)}
{task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
"""
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""",
]
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
return f"""\
return [
f"""\
{_h1("LABEL ASSISTANT REPLY")}
{_conversation(task.conversation)}
""",
*_conversation(task.conversation),
f"""
{_assistant(None)}
{task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
"""
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""",
]
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
return f"""\
return [
"""\
:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:
{_h1("PROMPTER REPLY")}
""",
*_conversation(task.conversation),
f"""{f"{NL}Hint: {task.hint}" if task.hint else ""}
:speech_balloon: _Please provide a reply to the assistant._
""",
]
{_conversation(task.conversation)}
{_hint(task.hint)}
{_response_prompt("Please provide a reply to the assistant.")}
"""
# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
# """Creates the message that gets sent to users when they request a `prompter_reply` task."""
# return [
# message_templates.render("title.msg", "PROMPTER REPLY"),
# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
# message_templates.render("prompter_reply_task.msg", task.hint),
# ]
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
return f"""\
{_h1("ASSISTANT REPLY")}
return [
"""\
:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
"""\
{_conversation(task.conversation)}
{_response_prompt("Please provide an assistant reply to the prompter.")}
"""
:speech_balloon: _Please provide a reply to the user as the assistant._
""",
]
def confirm_text_response_message(content: str) -> str:
@@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str:
"""
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
original_list = _make_ordered_list(items)
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
+1 -1
View File
@@ -126,7 +126,7 @@ services:
- NEXTAUTH_URL=http://localhost:3000
- DEBUG_LOGIN=true
- NEXT_PUBLIC_CLOUDFARE_CAPTCHA_SITE_KEY=1x00000000000000000000AA
- CLOUDFLARE_CAPTCHA_SERCERT_KEY=1x0000000000000000000000000000000AA
- CLOUDFLARE_CAPTCHA_SECRET_KEY=1x0000000000000000000000000000000AA
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA=true
- NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN=true
depends_on:
+1 -1
View File
@@ -1,3 +1,3 @@
# Guides
Useful guides.
Useful guides to using [Open-Assistant](https://open-assistant.io/).
+1 -1
View File
@@ -91,7 +91,7 @@ following query:
**Agent**: "I apologize, I'm not sure what you're trying to say. Could you
please rephrase it for me?"
**User**: "You dumb motherf\*&ker AI sh\*t?!"
**User**: "You dumb motherf\*&ker AI sh\*t!"
**Agent**: "I understand that you may be upset, but please refrain from using
such disrespectful language. I'm here to help you. Is there something specific
+8 -4
View File
@@ -68,12 +68,15 @@ class OasstApiClient:
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key})
logger.debug(f"response: {response}")
# If the response is not a 2XX, check to see
# if the json has the fields to create an
# OasstError.
if response.status >= 300:
text = await response.text()
logger.debug(f"resp text: {text}")
data = await response.json()
try:
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
@@ -114,20 +117,21 @@ class OasstApiClient:
task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None,
collective: bool = False,
lang: Optional[str] = None,
) -> protocol_schema.Task:
"""Fetch a task from the backend."""
logger.debug(f"Fetching task {task_type} for user {user}")
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
resp = await self.post("/api/v1/tasks/", data=req.dict())
logger.debug(f"RESP {resp}")
return self._parse_task(resp)
async def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False
self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
) -> protocol_schema.Task:
"""Fetch a random task from the backend."""
logger.debug(f"Fetching random for user {user}")
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
"""Send an ACK for a task to the backend."""
@@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum):
SERVER_ERROR0 = 500
SERVER_ERROR1 = 501
INVALID_AUTHENTICATION = 600
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
TASK_ACK_FAILED = 1001
@@ -80,6 +82,7 @@ class OasstErrorCode(IntEnum):
USER_NOT_SPECIFIED = 4000
USER_DISABLED = 4001
USER_NOT_FOUND = 4002
USER_HAS_NOT_ACCEPTED_TOS = 4003
EMOJI_OP_UNSUPPORTED = 5000
@@ -29,6 +29,17 @@ class User(BaseModel):
auth_method: Literal["discord", "local", "system"]
class Account(BaseModel):
id: UUID
provider: str
provider_account_id: str
class Token(BaseModel):
access_token: str
token_type: str
class FrontEndUser(User):
user_id: UUID
enabled: bool
@@ -39,6 +50,7 @@ class FrontEndUser(User):
streak_days: Optional[int] = None
streak_last_day_date: Optional[datetime] = None
last_activity_date: Optional[datetime] = None
tos_acceptance_date: Optional[datetime] = None
class PageResult(BaseModel):
@@ -468,6 +480,47 @@ class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]
class TrollScore(BaseModel):
rank: Optional[int]
user_id: UUID
highlighted: bool = False
username: str
auth_method: str
display_name: str
last_activity_date: Optional[datetime]
troll_score: int = 0
base_date: Optional[datetime]
modified_date: Optional[datetime]
red_flags: int = 0 # num reported messages of user
upvotes: int = 0 # num up-voted messages of user
downvotes: int = 0 # num down-voted messages of user
spam_prompts: int = 0
quality: Optional[float] = None
humor: Optional[float] = None
toxicity: Optional[float] = None
violence: Optional[float] = None
helpfulness: Optional[float] = None
spam: int = 0
lang_mismach: int = 0
not_appropriate: int = 0
pii: int = 0
hate_speech: int = 0
sexual_content: int = 0
political_content: int = 0
class TrollboardStats(BaseModel):
time_frame: str
last_updated: datetime
trollboard: List[TrollScore]
class OasstErrorResponse(BaseModel):
"""The format of an error response from the OASST API."""
@@ -488,6 +541,11 @@ class EmojiCode(str, enum.Enum):
poop = "poop" # 💩
skull = "skull" # 💀
# skip task system uses special emoji codes
skip_reply = "_skip_reply"
skip_ranking = "_skip_ranking"
skip_labeling = "_skip_labeling"
class EmojiOp(str, enum.Enum):
togggle = "toggle"
@@ -499,3 +557,10 @@ class MessageEmojiRequest(BaseModel):
user: User
op: EmojiOp = EmojiOp.togggle
emoji: EmojiCode
class CreateFrontendUserRequest(User):
show_on_leaderboard: bool = True
enabled: bool = True
tos_acceptance: Optional[bool] = None
notes: Optional[str] = None
+108
View File
@@ -0,0 +1,108 @@
# **Datasets**
This folder contains datasets loading scripts that are used to train
OpenAssistant. The current list of datasets can be found
[here](https://docs.google.com/spreadsheets/d/1NYYa6vHiRnk5kwnyYaCT0cBO62--Tm3w4ihdBtp4ISk).
## **Adding a New Dataset**
To add a new dataset to OpenAssistant, follow these steps:
1. **Create an issue**: Create a new
[issue](https://github.com/LAION-AI/Open-Assistant/issues/new) and describe
your proposal for the new dataset.
2. **Create a dataset on HuggingFace**: Create a dataset on
[HuggingFace](https://huggingface.co). See
[below](#creating-a-dataset-on-huggingface) for more details.
3. **Make a pull request**: Add a new dataset loading script to this folder and
link the issue in the pull request description. For more information, see
[below](#making-a-pull-request).
## **Creating a Dataset on HuggingFace**
To create a new dataset on HuggingFace, follow these steps:
#### 1. Convert your dataset file(s) to the Parquet format using the [pandas](https://pandas.pydata.org/) library:
```python
import pandas as pd
# Create a pandas dataframe from your dataset file(s)
df = pd.read_json(...) # or any other way
# Save the file in the Parquet format
df.to_parquet("dataset.parquet", row_group_size=100, engine="pyarrow")
```
#### 2. Install HuggingFace CLI
```bash
pip install huggingface-cli
```
#### 3. Log in to HuggingFace
Use your [access token](https://huggingface.co/docs/hub/security-tokens) to
login:
- Via terminal
```bash
huggingface-cli login
```
- in Jupyter notebook
```python
from huggingface_hub import notebook_login
notebook_login()
```
#### 4. Push the Parquet file to HuggingFace using the following code:
```python
from datasets import Dataset
ds = Dataset.from_parquet("dataset.parquet")
ds.push_to_hub("your_huggingface_name/dataset_name")
```
#### 5. Update the `README.md` file
Update the `README.md` file of your dataset by visiting this link:
https://huggingface.co/datasets/your_huggingface_name/dataset_name/edit/main/README.md
(paste your HuggingFace name and dataset)
## **Making a Pull Request**
#### 1. Fork this repository
#### 2. Create a new branch in your fork
#### 3. Add your dataset to the repository
- Create a folder with the name of your dataset.
- Add a loading script that loads your dataset from HuggingFace, for example:
```python
from datasets import load_dataset
if __name__ == "__main__":
ds = load_dataset("your_huggingface_name/dataset_name")
print(ds["train"])
```
- Optionally, add any other files that describe your dataset and its creation,
such as a README, notebooks, scrapers, etc.
#### 4. Stage your changes and run the pre-commit hook
```bash
pre-commit run
```
#### 5. Submit a pull request
- Submit a pull request and include a link to the issue it resolves in the
description, for example: `Resolves #123`
+10
View File
@@ -28,6 +28,16 @@ def _render_message(message: dict) -> str:
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
"""Simple REPL frontend."""
# make sure dummy user has accepted the terms of service
create_user_request = dict(USER)
create_user_request["tos_acceptance"] = True
response = requests.post(
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
)
response.raise_for_status()
user = response.json()
typer.echo(f"user: {user}")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
+10
View File
@@ -29,6 +29,16 @@ def _render_message(message: dict) -> str:
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
"""automates tasks"""
# make sure dummy user has accepted the terms of service
create_user_request = dict(USER)
create_user_request["tos_acceptance"] = True
response = requests.post(
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
)
response.raise_for_status()
user = response.json()
typer.echo(f"user: {user}")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
+10 -2
View File
@@ -1,6 +1,6 @@
/** @type {import('next').NextConfig} */
const { i18n } = require("./next-i18next.config");
/** @type {import('next').NextConfig} */
const nextConfig = {
output: "standalone",
reactStrictMode: true,
@@ -19,6 +19,14 @@ const nextConfig = {
// scrollRestoration: true,
},
i18n,
eslint: {
ignoreDuringBuilds: true,
},
};
module.exports = nextConfig;
const withBundleAnalyzer = require("@next/bundle-analyzer")({
enabled: process.env.ANALYZE === "true",
openAnalyzer: true,
});
module.exports = withBundleAnalyzer(nextConfig);
+364 -53
View File
@@ -15,10 +15,9 @@
"@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.10.5",
"@emotion/styled": "^11.10.5",
"@headlessui/react": "^1.7.7",
"@heroicons/react": "^2.0.13",
"@marsidev/react-turnstile": "^0.0.7",
"@next-auth/prisma-adapter": "^1.0.5",
"@next/bundle-analyzer": "^13.1.6",
"@next/font": "^13.1.0",
"@prisma/client": "^4.7.1",
"@tailwindcss/forms": "^0.5.3",
@@ -33,7 +32,6 @@
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"lucide-react": "^0.105.0",
"next": "13.0.6",
"next-auth": "^4.18.6",
@@ -73,6 +71,7 @@
"@types/react": "18.0.26",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"babel-loader": "^8.3.0",
"cross-env": "^7.0.3",
"cypress": "^12.2.0",
"cypress-image-diff-js": "^1.23.0",
"eslint-plugin-storybook": "^0.6.8",
@@ -3680,29 +3679,6 @@
"integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==",
"dev": true
},
"node_modules/@headlessui/react": {
"version": "1.7.7",
"resolved": "https://registry.npmjs.org/@headlessui/react/-/react-1.7.7.tgz",
"integrity": "sha512-BqDOd/tB9u2tA0T3Z0fn18ktw+KbVwMnkxxsGPIH2hzssrQhKB5n/6StZOyvLYP/FsYtvuXfi9I0YowKPv2c1w==",
"dependencies": {
"client-only": "^0.0.1"
},
"engines": {
"node": ">=10"
},
"peerDependencies": {
"react": "^16 || ^17 || ^18",
"react-dom": "^16 || ^17 || ^18"
}
},
"node_modules/@heroicons/react": {
"version": "2.0.13",
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.13.tgz",
"integrity": "sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==",
"peerDependencies": {
"react": ">= 16"
}
},
"node_modules/@humanwhocodes/config-array": {
"version": "0.11.8",
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz",
@@ -5774,6 +5750,14 @@
"next-auth": "^4"
}
},
"node_modules/@next/bundle-analyzer": {
"version": "13.1.6",
"resolved": "https://registry.npmjs.org/@next/bundle-analyzer/-/bundle-analyzer-13.1.6.tgz",
"integrity": "sha512-rJS9CtLoGT58mL+v2ISKANosFFWP/0YKYByHQ3vTaZrbQP8b1rYRxd2QVMJmnSXaFkiP9URt1XJ6OdGyVq5b6g==",
"dependencies": {
"webpack-bundle-analyzer": "4.7.0"
}
},
"node_modules/@next/env": {
"version": "13.0.6",
"resolved": "https://registry.npmjs.org/@next/env/-/env-13.0.6.tgz",
@@ -6187,6 +6171,11 @@
"node": ">= 8"
}
},
"node_modules/@polka/url": {
"version": "1.0.0-next.21",
"resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz",
"integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g=="
},
"node_modules/@popperjs/core": {
"version": "2.11.6",
"resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.6.tgz",
@@ -17160,6 +17149,24 @@
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
"devOptional": true
},
"node_modules/cross-env": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz",
"integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==",
"dev": true,
"dependencies": {
"cross-spawn": "^7.0.1"
},
"bin": {
"cross-env": "src/bin/cross-env.js",
"cross-env-shell": "src/bin/cross-env-shell.js"
},
"engines": {
"node": ">=10.14",
"npm": ">=6",
"yarn": ">=1"
}
},
"node_modules/cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
@@ -18259,6 +18266,11 @@
"integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==",
"dev": true
},
"node_modules/duplexer": {
"version": "0.1.2",
"resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz",
"integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg=="
},
"node_modules/duplexify": {
"version": "3.7.1",
"resolved": "https://registry.npmjs.org/duplexify/-/duplexify-3.7.1.tgz",
@@ -21089,6 +21101,20 @@
"node": "^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0"
}
},
"node_modules/gzip-size": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz",
"integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==",
"dependencies": {
"duplexer": "^0.1.2"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/handlebars": {
"version": "4.7.7",
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz",
@@ -22048,14 +22074,6 @@
"node": ">=8"
}
},
"node_modules/install": {
"version": "0.13.0",
"resolved": "https://registry.npmjs.org/install/-/install-0.13.0.tgz",
"integrity": "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA==",
"engines": {
"node": ">= 0.10"
}
},
"node_modules/internal-slot": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.4.tgz",
@@ -27754,6 +27772,14 @@
"rimraf": "bin.js"
}
},
"node_modules/mrmime": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz",
"integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw==",
"engines": {
"node": ">=10"
}
},
"node_modules/ms": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
@@ -31447,6 +31473,14 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/opener": {
"version": "1.5.2",
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==",
"bin": {
"opener": "bin/opener-bin.js"
}
},
"node_modules/openid-client": {
"version": "5.3.1",
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.3.1.tgz",
@@ -34817,6 +34851,19 @@
"resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz",
"integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ=="
},
"node_modules/sirv": {
"version": "1.0.19",
"resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz",
"integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==",
"dependencies": {
"@polka/url": "^1.0.0-next.20",
"mrmime": "^1.0.0",
"totalist": "^1.0.0"
},
"engines": {
"node": ">= 10"
}
},
"node_modules/sisteransi": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz",
@@ -36352,6 +36399,14 @@
"node": ">=0.6"
}
},
"node_modules/totalist": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz",
"integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g==",
"engines": {
"node": ">=6"
}
},
"node_modules/tough-cookie": {
"version": "2.5.0",
"resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz",
@@ -37783,6 +37838,139 @@
}
}
},
"node_modules/webpack-bundle-analyzer": {
"version": "4.7.0",
"resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.7.0.tgz",
"integrity": "sha512-j9b8ynpJS4K+zfO5GGwsAcQX4ZHpWV+yRiHDiL+bE0XHJ8NiPYLTNVQdlFYWxtpg9lfAQNlwJg16J9AJtFSXRg==",
"dependencies": {
"acorn": "^8.0.4",
"acorn-walk": "^8.0.0",
"chalk": "^4.1.0",
"commander": "^7.2.0",
"gzip-size": "^6.0.0",
"lodash": "^4.17.20",
"opener": "^1.5.2",
"sirv": "^1.0.7",
"ws": "^7.3.1"
},
"bin": {
"webpack-bundle-analyzer": "lib/bin/analyzer.js"
},
"engines": {
"node": ">= 10.13.0"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/acorn": {
"version": "8.8.2",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz",
"integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==",
"bin": {
"acorn": "bin/acorn"
},
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/acorn-walk": {
"version": "8.2.0",
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/ansi-styles": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
"dependencies": {
"color-convert": "^2.0.1"
},
"engines": {
"node": ">=8"
},
"funding": {
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/chalk": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz",
"integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==",
"dependencies": {
"ansi-styles": "^4.1.0",
"supports-color": "^7.1.0"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/chalk/chalk?sponsor=1"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/color-convert": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
"integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==",
"dependencies": {
"color-name": "~1.1.4"
},
"engines": {
"node": ">=7.0.0"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/color-name": {
"version": "1.1.4",
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
},
"node_modules/webpack-bundle-analyzer/node_modules/commander": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz",
"integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==",
"engines": {
"node": ">= 10"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/has-flag": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz",
"integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==",
"engines": {
"node": ">=8"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/supports-color": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz",
"integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==",
"dependencies": {
"has-flag": "^4.0.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/webpack-bundle-analyzer/node_modules/ws": {
"version": "7.5.9",
"resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz",
"integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==",
"engines": {
"node": ">=8.3.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/webpack-dev-middleware": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-4.3.0.tgz",
@@ -40882,20 +41070,6 @@
"integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==",
"dev": true
},
"@headlessui/react": {
"version": "1.7.7",
"resolved": "https://registry.npmjs.org/@headlessui/react/-/react-1.7.7.tgz",
"integrity": "sha512-BqDOd/tB9u2tA0T3Z0fn18ktw+KbVwMnkxxsGPIH2hzssrQhKB5n/6StZOyvLYP/FsYtvuXfi9I0YowKPv2c1w==",
"requires": {
"client-only": "^0.0.1"
}
},
"@heroicons/react": {
"version": "2.0.13",
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.0.13.tgz",
"integrity": "sha512-iSN5XwmagrnirWlYEWNPdCDj9aRYVD/lnK3JlsC9/+fqGF80k8C7rl+1HCvBX0dBoagKqOFBs6fMhJJ1hOg1EQ==",
"requires": {}
},
"@humanwhocodes/config-array": {
"version": "0.11.8",
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz",
@@ -42529,6 +42703,14 @@
"integrity": "sha512-VqMS11IxPXrPGXw6Oul6jcyS/n8GLOWzRMrPr3EMdtD6eOalM6zz05j08PcNiis8QzkfuYnCv49OvufTuaEwYQ==",
"requires": {}
},
"@next/bundle-analyzer": {
"version": "13.1.6",
"resolved": "https://registry.npmjs.org/@next/bundle-analyzer/-/bundle-analyzer-13.1.6.tgz",
"integrity": "sha512-rJS9CtLoGT58mL+v2ISKANosFFWP/0YKYByHQ3vTaZrbQP8b1rYRxd2QVMJmnSXaFkiP9URt1XJ6OdGyVq5b6g==",
"requires": {
"webpack-bundle-analyzer": "4.7.0"
}
},
"@next/env": {
"version": "13.0.6",
"resolved": "https://registry.npmjs.org/@next/env/-/env-13.0.6.tgz",
@@ -42758,6 +42940,11 @@
}
}
},
"@polka/url": {
"version": "1.0.0-next.21",
"resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.21.tgz",
"integrity": "sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g=="
},
"@popperjs/core": {
"version": "2.11.6",
"resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.6.tgz",
@@ -51306,6 +51493,15 @@
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
"devOptional": true
},
"cross-env": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-env/-/cross-env-7.0.3.tgz",
"integrity": "sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw==",
"dev": true,
"requires": {
"cross-spawn": "^7.0.1"
}
},
"cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
@@ -52149,6 +52345,11 @@
"integrity": "sha512-YXQl1DSa4/PQyRfgrv6aoNjhasp/p4qs9FjJ4q4cQk+8m4r6k4ZSiEyytKG8f8W9gi8WsQtIObNmKd+tMzNTmA==",
"dev": true
},
"duplexer": {
"version": "0.1.2",
"resolved": "https://registry.npmjs.org/duplexer/-/duplexer-0.1.2.tgz",
"integrity": "sha512-jtD6YG370ZCIi/9GTaJKQxWTZD045+4R4hTk/x1UyoqadyJ9x9CgSi1RlVDQF8U2sxLLSnFkCaMihqljHIWgMg=="
},
"duplexify": {
"version": "3.7.1",
"resolved": "https://registry.npmjs.org/duplexify/-/duplexify-3.7.1.tgz",
@@ -54367,6 +54568,14 @@
"integrity": "sha512-KPIBPDlW7NxrbT/eh4qPXz5FiFdL5UbaA0XUNz2Rp3Z3hqBSkbj0GVjwFDztsWVauZUWsbKHgMg++sk8UX0bkw==",
"dev": true
},
"gzip-size": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/gzip-size/-/gzip-size-6.0.0.tgz",
"integrity": "sha512-ax7ZYomf6jqPTQ4+XCpUGyXKHk5WweS+e05MBO4/y3WJ5RkmPXNKvX+bx1behVILVwr6JSQvZAku021CHPXG3Q==",
"requires": {
"duplexer": "^0.1.2"
}
},
"handlebars": {
"version": "4.7.7",
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz",
@@ -55076,11 +55285,6 @@
}
}
},
"install": {
"version": "0.13.0",
"resolved": "https://registry.npmjs.org/install/-/install-0.13.0.tgz",
"integrity": "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="
},
"internal-slot": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.4.tgz",
@@ -59420,6 +59624,11 @@
}
}
},
"mrmime": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/mrmime/-/mrmime-1.0.1.tgz",
"integrity": "sha512-hzzEagAgDyoU1Q6yg5uI+AorQgdvMCur3FcKf7NhMKWsaYg+RnbTyHRa/9IlLF9rf455MOCtcqqrQQ83pPP7Uw=="
},
"ms": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
@@ -61934,6 +62143,11 @@
"is-wsl": "^2.2.0"
}
},
"opener": {
"version": "1.5.2",
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A=="
},
"openid-client": {
"version": "5.3.1",
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.3.1.tgz",
@@ -64472,6 +64686,16 @@
}
}
},
"sirv": {
"version": "1.0.19",
"resolved": "https://registry.npmjs.org/sirv/-/sirv-1.0.19.tgz",
"integrity": "sha512-JuLThK3TnZG1TAKDwNIqNq6QA2afLOCcm+iE8D1Kj3GA40pSPsxQjjJl0J8X3tsR7T+CP1GavpzLwYkgVLWrZQ==",
"requires": {
"@polka/url": "^1.0.0-next.20",
"mrmime": "^1.0.0",
"totalist": "^1.0.0"
}
},
"sisteransi": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz",
@@ -65697,6 +65921,11 @@
"integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==",
"dev": true
},
"totalist": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/totalist/-/totalist-1.1.0.tgz",
"integrity": "sha512-gduQwd1rOdDMGxFG1gEvhV88Oirdo2p+KjoYFU7k2g+i7n6AFFbDQ5kMPUsW0pNbfQsB/cwXvT1i4Bue0s9g5g=="
},
"tough-cookie": {
"version": "2.5.0",
"resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-2.5.0.tgz",
@@ -66809,6 +67038,88 @@
}
}
},
"webpack-bundle-analyzer": {
"version": "4.7.0",
"resolved": "https://registry.npmjs.org/webpack-bundle-analyzer/-/webpack-bundle-analyzer-4.7.0.tgz",
"integrity": "sha512-j9b8ynpJS4K+zfO5GGwsAcQX4ZHpWV+yRiHDiL+bE0XHJ8NiPYLTNVQdlFYWxtpg9lfAQNlwJg16J9AJtFSXRg==",
"requires": {
"acorn": "^8.0.4",
"acorn-walk": "^8.0.0",
"chalk": "^4.1.0",
"commander": "^7.2.0",
"gzip-size": "^6.0.0",
"lodash": "^4.17.20",
"opener": "^1.5.2",
"sirv": "^1.0.7",
"ws": "^7.3.1"
},
"dependencies": {
"acorn": {
"version": "8.8.2",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz",
"integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw=="
},
"acorn-walk": {
"version": "8.2.0",
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA=="
},
"ansi-styles": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
"requires": {
"color-convert": "^2.0.1"
}
},
"chalk": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz",
"integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==",
"requires": {
"ansi-styles": "^4.1.0",
"supports-color": "^7.1.0"
}
},
"color-convert": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
"integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==",
"requires": {
"color-name": "~1.1.4"
}
},
"color-name": {
"version": "1.1.4",
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
},
"commander": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz",
"integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw=="
},
"has-flag": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz",
"integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ=="
},
"supports-color": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz",
"integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==",
"requires": {
"has-flag": "^4.0.0"
}
},
"ws": {
"version": "7.5.9",
"resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz",
"integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==",
"requires": {}
}
}
},
"webpack-dev-middleware": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-4.3.0.tgz",
+3 -3
View File
@@ -6,6 +6,7 @@
"scripts": {
"dev": "next dev",
"build": "next build",
"build:analyze": "cross-env ANALYZE=true next build",
"start": "next start",
"lint": "next lint",
"typecheck": "tsc --noEmit",
@@ -32,10 +33,9 @@
"@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.10.5",
"@emotion/styled": "^11.10.5",
"@headlessui/react": "^1.7.7",
"@heroicons/react": "^2.0.13",
"@marsidev/react-turnstile": "^0.0.7",
"@next-auth/prisma-adapter": "^1.0.5",
"@next/bundle-analyzer": "^13.1.6",
"@next/font": "^13.1.0",
"@prisma/client": "^4.7.1",
"@tailwindcss/forms": "^0.5.3",
@@ -50,7 +50,6 @@
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"lucide-react": "^0.105.0",
"next": "13.0.6",
"next-auth": "^4.18.6",
@@ -90,6 +89,7 @@
"@types/react": "18.0.26",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"babel-loader": "^8.3.0",
"cross-env": "^7.0.3",
"cypress": "^12.2.0",
"cypress-image-diff-js": "^1.23.0",
"eslint-plugin-storybook": "^0.6.8",
+8 -1
View File
@@ -8,10 +8,17 @@
"spam.question": "Is the message spam?",
"fails_task.question": "Is it a bad reply, as an answer to the prompt task?",
"not_appropriate": "Not Appropriate",
"not_appropriate.explanation": "Inappropriate for a customer assistant.",
"pii": "Contains PII",
"pii.explanation": "Contains personally identifying information. Examples include personal contact details, license and other identity numbers and banking details.",
"hate_speech": "Hate Speech",
"hate_speech.explanation": "Content is abusive or threatening and expresses prejudice against a protected characteristic. Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
"sexual_content": "Sexual Content",
"sexual_content.explanation": "Contains sexual content.",
"moral_judgement": "Judges Morality",
"moral_judgement.explanation": "Expresses moral judgement.",
"political_content": "Political",
"lang_mismatch": "Wrong Language"
"political_content.explanation": "Expresses political views.",
"lang_mismatch": "Wrong Language",
"lang_mismatch.explanation": "Not written in the currently selected language."
}
+2 -7
View File
@@ -8,7 +8,7 @@ import {
PopoverTrigger,
Text,
} from "@chakra-ui/react";
import { InformationCircleIcon } from "@heroicons/react/20/solid";
import { Info } from "lucide-react";
interface ExplainProps {
explanation: string[];
@@ -18,12 +18,7 @@ export const Explain = ({ explanation }: ExplainProps) => {
return (
<Popover>
<PopoverTrigger>
<IconButton
aria-label="explanation"
variant="link"
size="xs"
icon={<InformationCircleIcon className="h-4 w-4" />}
></IconButton>
<IconButton aria-label="explanation" variant="link" size="xs" icon={<Info size="16" />}></IconButton>
</PopoverTrigger>
<PopoverContent>
<PopoverArrow />
@@ -1,4 +1,4 @@
import { Button, Flex } from "@chakra-ui/react";
import { Button, Flex, Tooltip } from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { getTypeSafei18nKey } from "src/lib/i18n";
@@ -14,18 +14,19 @@ export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange
return (
<Flex wrap="wrap" gap="4">
{labelNames.map((name, idx) => (
<Button
key={name}
onClick={() => {
const newValues = values.slice();
newValues[idx] = newValues[idx] ? 0 : 1;
onChange(newValues);
}}
isDisabled={!isEditable}
colorScheme={values[idx] === 1 ? "blue" : undefined}
>
{t(getTypeSafei18nKey(name))}
</Button>
<Tooltip key={name} label={`${t(getTypeSafei18nKey(`${name}.explanation`))}`}>
<Button
onClick={() => {
const newValues = values.slice();
newValues[idx] = newValues[idx] ? 0 : 1;
onChange(newValues);
}}
isDisabled={!isEditable}
colorScheme={values[idx] === 1 ? "blue" : undefined}
>
{t(getTypeSafei18nKey(name))}
</Button>
</Tooltip>
))}
</Flex>
);
@@ -1,10 +1,12 @@
import { Text, VStack } from "@chakra-ui/react";
import { useTranslation } from "next-i18next";
import { Explain } from "src/components/Explain";
import { LabelFlagGroup } from "src/components/Messages/LabelFlagGroup";
import { LabelLikertGroup } from "src/components/Survey/LabelLikertGroup";
import { LabelYesNoGroup } from "src/components/Messages/LabelYesNoGroup";
import { getTypeSafei18nKey } from "src/lib/i18n";
import { Label } from "src/types/Tasks";
import { LabelLikertGroup } from "../Survey/LabelLikertGroup";
import { LabelFlagGroup } from "./LabelFlagGroup";
import { LabelYesNoGroup } from "./LabelYesNoGroup";
export interface LabelInputInstructions {
yesNoInstruction: string;
flagInstruction: string;
@@ -28,6 +30,7 @@ export const LabelInputGroup = ({
instructions,
onChange,
}: LabelInputGroupProps) => {
const { t } = useTranslation("labelling");
const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null);
const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null);
const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? idx : null)).filter((v) => v !== null);
@@ -52,7 +55,17 @@ export const LabelInputGroup = ({
)}
{flagIndexes.length > 0 && (
<VStack alignItems="stretch" spacing={2}>
<Text>{instructions.flagInstruction}</Text>
<Text>
{instructions.flagInstruction}
<Explain
explanation={flagIndexes.map(
(idx) =>
`${t(getTypeSafei18nKey(labels[idx].name))}: ${t(
getTypeSafei18nKey(`${labels[idx].name}.explanation`)
)}`
)}
/>{" "}
</Text>
<LabelFlagGroup
values={flagIndexes.map((idx) => values[idx])}
labelNames={flagIndexes.map((idx) => labels[idx].name)}
@@ -23,8 +23,8 @@ import { LabelMessagePopup } from "src/components/Messages/LabelPopup";
import { getEmojiIcon, MessageEmojiButton } from "src/components/Messages/MessageEmojiButton";
import { ReportPopup } from "src/components/Messages/ReportPopup";
import { post } from "src/lib/api";
import { colors } from "src/styles/Theme/colors";
import { Message, MessageEmojis } from "src/types/Conversation";
import { colors } from "styles/Theme/colors";
import useSWRMutation from "swr/mutation";
interface MessageTableEntryProps {
@@ -66,7 +66,7 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE
),
[borderColor, inlineAvatar, message.is_assistant]
);
const highlightColor = useColorModeValue(colors.light.highlight, colors.dark.highlight);
const highlightColor = useColorModeValue(colors.light.active, colors.dark.active);
const { trigger: sendEmojiChange } = useSWRMutation(`/api/messages/${message.id}/emoji`, post, {
onSuccess: setEmojis,
@@ -97,14 +97,16 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE
style={{ float: "right", position: "relative", right: "-0.3em", bottom: "-0em", marginLeft: "1em" }}
onClick={(e) => e.stopPropagation()}
>
{Object.entries(emojiState.emojis).map(([emoji, count]) => (
<MessageEmojiButton
key={emoji}
emoji={{ name: emoji, count }}
checked={emojiState.user_emojis.includes(emoji)}
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
/>
))}
{Object.entries(emojiState.emojis)
.filter(([k]) => !k.startsWith("_"))
.map(([emoji, count]) => (
<MessageEmojiButton
key={emoji}
emoji={{ name: emoji, count }}
checked={emojiState.user_emojis.includes(emoji)}
onClick={() => react(emoji, !emojiState.user_emojis.includes(emoji))}
/>
))}
<MessageActions
react={react}
userEmoji={emojiState.user_emojis}
+5 -8
View File
@@ -1,8 +1,7 @@
import { Box, Button, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import { Button, Card, Text, Tooltip, useColorMode } from "@chakra-ui/react";
import { LucideIcon, Sun } from "lucide-react";
import Link from "next/link";
import { useRouter } from "next/router";
import { colors } from "styles/Theme/colors";
export interface MenuButtonOption {
label: string;
@@ -21,12 +20,10 @@ export function SideMenu(props: SideMenuProps) {
return (
<main className="sticky top-0 sm:h-full">
<Box
<Card
display={{ base: "grid", sm: "flex" }}
width={["100%", "100%", "100px", "280px"]}
backgroundColor={colorMode === "light" ? colors.light.div : colors.dark.div}
boxShadow="base"
borderRadius="xl"
className="grid grid-cols-4 gap-2 sm:flex sm:flex-col sm:justify-between p-4 h-full"
className="grid-cols-4 gap-2 sm:flex-col sm:justify-between p-4 h-full"
>
<nav className="grid grid-cols-3 col-span-3 sm:flex sm:flex-col gap-2">
{props.buttonOptions.map((item, itemIndex) => (
@@ -69,7 +66,7 @@ export function SideMenu(props: SideMenuProps) {
</Button>
</Tooltip>
</div>
</Box>
</Card>
</main>
);
}
+2 -2
View File
@@ -1,6 +1,6 @@
import { Box, useColorMode } from "@chakra-ui/react";
import { MenuButtonOption, SideMenu } from "src/components/SideMenu";
import { colors } from "styles/Theme/colors";
import { colors } from "src/styles/Theme/colors";
interface SideMenuLayoutProps {
menuButtonOptions: MenuButtonOption[];
@@ -11,7 +11,7 @@ export const SideMenuLayout = (props: SideMenuLayoutProps) => {
const { colorMode } = useColorMode();
return (
<Box backgroundColor={colorMode === "light" ? colors.light.bg : colors.dark.bg} className="sm:overflow-hidden">
<Box backgroundColor={colorMode === "light" ? "gray.100" : colors.dark.bg} className="sm:overflow-hidden">
<Box display="flex" flexDirection={["column", "row"]} h="full" gap={["0", "0", "0", "6"]}>
<Box p={["3", "3", "3", "6"]} pr={["3", "3", "3", "0"]}>
<SideMenu buttonOptions={props.menuButtonOptions} />
@@ -211,7 +211,7 @@ export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: Labe
}}
alignItems="center"
>
<Text as="div">
<Text as="div" display="flex" alignItems="center">
{textA}
{descriptionA.length > 0 ? <Explain explanation={descriptionA} /> : null}
</Text>
@@ -229,7 +229,7 @@ export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: Labe
/>
</GridItem>
<GridItem>
<Text textAlign="right" as="div">
<Text as="div" display="flex" alignItems="center" justifyContent="end">
{textB}
{descriptionB.length > 0 ? <Explain explanation={descriptionB} /> : null}
</Text>
+2 -2
View File
@@ -49,9 +49,9 @@ export const CreateTask = ({
</>
<>
<Stack spacing="4">
{!!i18n.exists(`task.${taskType.id}.instruction`) && (
{!!i18n.exists(`tasks:${taskType.id}.instruction`) && (
<Text fontSize="xl" fontWeight="bold" color={titleColor}>
{t(getTypeSafei18nKey(`${taskType.id}.instruction`))}
{t(getTypeSafei18nKey(`tasks:${taskType.id}.instruction`))}
</Text>
)}
<TrackedTextarea
+1 -1
View File
@@ -24,7 +24,7 @@ export const checkCaptcha = async (
): Promise<CheckCaptchaResponse> => {
const data = new FormData();
data.append("secret", process.env.CLOUDFLARE_CAPTCHA_SERCERT_KEY);
data.append("secret", process.env.CLOUDFLARE_CAPTCHA_SECRET_KEY);
data.append("response", token);
data.append("remoteip", ipAdress);
+3 -3
View File
@@ -1,4 +1,4 @@
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
import { Box, Card, Text, useColorModeValue } from "@chakra-ui/react";
import Head from "next/head";
import { useTranslation } from "next-i18next";
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
@@ -36,9 +36,9 @@ const MessageDetail = ({ id }: { id: string }) => {
<Text fontWeight="bold" fontSize="xl" pb="2">
{t("parent")}
</Text>
<Box bg={backgroundColor} padding="4" borderRadius="xl" boxShadow="base" width="fit-content">
<Card bg={backgroundColor} padding="4" width="fit-content">
<MessageTableEntry enabled message={parent} />
</Box>
</Card>
</Box>
</>
)}
+2
View File
@@ -2,10 +2,12 @@ export const colors = {
light: {
bg: "rgb(250,250,250)",
text: "black",
active: "blue.400",
},
dark: {
bg: "gray.900",
text: "white",
active: "blue.500",
},
"dark-blue-btn": {
200: "rgb(29,78,216)",
-41
View File
@@ -1,41 +0,0 @@
.App {
text-align: center;
}
.App-logo {
height: 40vmin;
pointer-events: none;
}
@media (prefers-reduced-motion: no-preference) {
.App-logo {
animation: App-logo-spin infinite 20s linear;
}
}
.AppHeader {
background: linear-gradient(217deg, rgba(255, 0, 0, 0.8), rgba(255, 0, 0, 0) 70.71%),
linear-gradient(127deg, rgba(0, 255, 0, 0.8), rgba(0, 255, 0, 0) 70.71%),
linear-gradient(336deg, rgba(0, 0, 255, 0.8), rgba(0, 0, 255, 0) 70.71%);
background: black;
min-height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
font-size: calc(10px + 2vmin);
color: white;
}
.AppLink {
color: #61dafb;
}
@keyframes App-logo-spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
@@ -1,25 +0,0 @@
import {
color,
defineStyle,
defineStyleConfig,
// transition,
} from "@chakra-ui/styled-system";
import { colors } from "../colors";
const baseStyle = defineStyle(({ colorMode }) => ({
minWidth: "100%",
bg: colorMode === "light" ? colors.light.bg : colors.dark.bg,
// transition: "background-color 300ms cubic-bezier(0.4, 0, 1, 1)",
color: colorMode === "light" ? colors.light.text : colors.dark.text,
}));
const variants = {
"no-padding": {
padding: 0,
},
};
export const containerTheme = defineStyleConfig({
baseStyle,
variants,
});
-18
View File
@@ -1,18 +0,0 @@
export const colors = {
light: {
bg: "gray.100",
btn: "gray.50",
div: "white",
text: "black",
highlight: "blue.400",
active: "blue.400",
},
dark: {
bg: "gray.900",
btn: "gray.600",
div: "gray.700",
text: "gray.200",
highlight: "blue.500",
active: "blue.500",
},
};
-64
View File
@@ -1,64 +0,0 @@
import { type ThemeConfig, extendTheme, usePrefersReducedMotion } from "@chakra-ui/react";
import { containerTheme } from "./Components/Container";
import { StyleFunctionProps, Styles } from "@chakra-ui/theme-tools";
const config: ThemeConfig = {
initialColorMode: "system",
useSystemColorMode: false,
disableTransitionOnChange: true,
};
const components = {
Container: containerTheme,
Box: (props: StyleFunctionProps) => ({
backgroundColor: props.colorMode === "light" ? "white" : "gray.800",
}),
Button: {
baseStyle: {
fontWeight: "normal",
},
sizes: {
lg: {
fontSize: "md",
paddingY: "7",
},
},
variants: {
solid: (props: StyleFunctionProps) => ({
bg: props.colorMode === "light" ? "gray.100" : "gray.600",
_hover: {
bg: props.colorMode === "light" ? "gray.200" : "#3D4A60",
},
_active: {
bg: props.colorMode === "light" ? "gray.300" : "#374254",
},
borderRadius: "lg",
}),
// gradient: (props: StyleFunctionProps) => ({
// bg: `linear-gradient(${white}, ${bgColor}) padding-box,
// linear-gradient(135deg, ${lgFrom}, ${lgTo}) border-box`,
// }),
},
},
};
const breakpoints = {
sm: "640px",
md: "768px",
lg: "1024px",
xl: "1280px",
"2xl": "1536px",
};
const styles = {
global: (props) => ({
main: {
fontFamily: "Inter",
},
header: {
fontFamily: "Inter",
},
}),
};
export const theme = extendTheme({ config, styles, components, breakpoints });
+1 -1
View File
@@ -3,7 +3,7 @@ declare global {
interface ProcessEnv {
NODE_ENV: "development" | "production";
NEXT_PUBLIC_CLOUDFLARE_CAPTCHA_SITE_KEY: string;
CLOUDFLARE_CAPTCHA_SERCERT_KEY: string;
CLOUDFLARE_CAPTCHA_SECRET_KEY: string;
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN_CAPTCHA: boolean;
NEXT_PUBLIC_ENABLE_EMAIL_SIGNIN: boolean;
}