mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-03 17:10:10 +08:00
merging with main
This commit is contained in:
+31
@@ -0,0 +1,31 @@
|
||||
"""add origin column to message_tree_state
|
||||
|
||||
Revision ID: 49d8445b4c90
|
||||
Revises: f856bf19d32b
|
||||
Create Date: 2023-01-28 11:57:45.580027
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "49d8445b4c90"
|
||||
down_revision = "f856bf19d32b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True))
|
||||
op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("message_tree_state", "origin")
|
||||
op.drop_column("message", "model_name")
|
||||
op.drop_column("message", "synthetic")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,187 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
import oasst_backend.utils.database_utils as db_utils
|
||||
import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import create_api_client
|
||||
from oasst_backend.models import ApiClient, Message
|
||||
from oasst_backend.models.message_tree_state import MessageTreeState
|
||||
from oasst_backend.models.message_tree_state import State as TreeState
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree
|
||||
from sqlmodel import Session
|
||||
|
||||
# well known id
|
||||
IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363")
|
||||
|
||||
|
||||
class Importer:
|
||||
def __init__(self, db: Session, origin: str, model_name: Optional[str] = None):
|
||||
self.db = db
|
||||
self.origin = origin
|
||||
self.model_name = model_name
|
||||
|
||||
# get import api client
|
||||
api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first()
|
||||
if not api_client:
|
||||
api_client = create_api_client(
|
||||
session=db,
|
||||
description="API client used for importing data",
|
||||
frontend_type="import",
|
||||
force_id=IMPORT_API_CLIENT_ID,
|
||||
)
|
||||
|
||||
ur = UserRepository(db, api_client)
|
||||
self.import_user = ur.lookup_system_user(username="import")
|
||||
self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur)
|
||||
self.api_client = api_client
|
||||
|
||||
def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
|
||||
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none()
|
||||
|
||||
def import_message(
|
||||
self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None
|
||||
) -> Message:
|
||||
payload = db_payload.MessagePayload(text=message.text)
|
||||
msg = Message(
|
||||
id=message.message_id,
|
||||
message_tree_id=message_tree_id,
|
||||
frontend_message_id=message.message_id,
|
||||
parent_id=parent_id,
|
||||
review_count=message.review_count or 0,
|
||||
lang=message.lang or "en",
|
||||
review_result=True,
|
||||
synthetic=message.synthetic if message.synthetic is not None else True,
|
||||
model_name=message.model_name or self.model_name,
|
||||
role=message.role,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
user_id=self.import_user.id,
|
||||
)
|
||||
self.db.add(msg)
|
||||
if message.replies:
|
||||
for r in message.replies:
|
||||
self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id)
|
||||
self.db.flush()
|
||||
if parent_id is None:
|
||||
self.pr.update_children_counts(msg.id)
|
||||
self.db.refresh(msg)
|
||||
return msg
|
||||
|
||||
def import_tree(
|
||||
self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING
|
||||
) -> tuple[MessageTreeState, Message]:
|
||||
assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id
|
||||
root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id)
|
||||
assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import"
|
||||
active = state == TreeState.RANKING
|
||||
mts = MessageTreeState(
|
||||
message_tree_id=root_msg.id,
|
||||
goal_tree_size=0,
|
||||
max_depth=0,
|
||||
max_children_count=0,
|
||||
state=state,
|
||||
origin=self.origin,
|
||||
active=active,
|
||||
)
|
||||
self.db.add(mts)
|
||||
return mts, root_msg
|
||||
|
||||
|
||||
def import_file(
|
||||
input_file_path: Path,
|
||||
origin: str,
|
||||
*,
|
||||
model_name: Optional[str] = None,
|
||||
num_activate: int = 0,
|
||||
max_count: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> int:
|
||||
@db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT)
|
||||
def import_tx(db: Session) -> int:
|
||||
importer = Importer(db, origin=origin, model_name=model_name)
|
||||
i = 0
|
||||
with input_file_path.open() as file_in:
|
||||
# read line tree object
|
||||
for line in file_in:
|
||||
dict_tree = json.loads(line)
|
||||
|
||||
# validate data
|
||||
tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree)
|
||||
existing_mts = importer.fetch_message_tree_state(tree.message_tree_id)
|
||||
if existing_mts:
|
||||
logger.info(f"Skipping existing message tree: {tree.message_tree_id}")
|
||||
else:
|
||||
state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING
|
||||
mts, root_msg = importer.import_tree(tree, state=state)
|
||||
i += 1
|
||||
logger.info(
|
||||
f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}"
|
||||
)
|
||||
|
||||
if max_count and i >= max_count:
|
||||
logger.info(f"Reached max count {max_count} of trees to import.")
|
||||
break
|
||||
return i
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN with rollback")
|
||||
return import_tx()
|
||||
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"input_file_path",
|
||||
help="Input file path",
|
||||
)
|
||||
parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees")
|
||||
parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)")
|
||||
parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state")
|
||||
parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import")
|
||||
parser.add_argument("--dry_run", type=str2bool, default=False)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
input_file_path = Path(args.input_file_path)
|
||||
if not input_file_path.exists() or not input_file_path.is_file():
|
||||
print("Invalid input file:", args.input_file_path)
|
||||
exit(1)
|
||||
|
||||
dry_run = args.dry_run
|
||||
num_imported = import_file(
|
||||
input_file_path,
|
||||
origin=args.origin or input_file_path.name,
|
||||
model_name=args.model_name,
|
||||
num_activate=args.num_activate,
|
||||
max_count=args.max_count,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
logger.info(f"Done ({num_imported=}, {dry_run=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+3
-1
@@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
check_duplicate=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
@@ -215,7 +216,8 @@ def ensure_tree_states():
|
||||
try:
|
||||
logger.info("Startup: TreeManager.ensure_tree_states()")
|
||||
with Session(engine) as db:
|
||||
tm = TreeManager(db, None)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
tm = TreeManager(db, PromptRepository(db, api_client=api_client))
|
||||
tm.ensure_tree_states()
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator, NamedTuple
|
||||
from typing import Generator, NamedTuple, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
@@ -67,6 +68,7 @@ def create_api_client(
|
||||
trusted: bool | None = False,
|
||||
admin_email: str | None = None,
|
||||
api_key: str | None = None,
|
||||
force_id: Optional[UUID] = None,
|
||||
) -> ApiClient:
|
||||
if api_key is None:
|
||||
api_key = token_hex(32)
|
||||
@@ -79,6 +81,8 @@ def create_api_client(
|
||||
trusted=trusted,
|
||||
admin_email=admin_email,
|
||||
)
|
||||
if force_id:
|
||||
api_client.id = force_id
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
session.refresh(api_client)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -43,11 +49,28 @@ def label_text(
|
||||
|
||||
|
||||
@router.get("/valid_labels")
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
def get_valid_lables(
|
||||
*,
|
||||
message_id: Optional[UUID] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
) -> ValidLabelsResponse:
|
||||
if message_id:
|
||||
pr = PromptRepository(db, api_client=api_client)
|
||||
message = pr.fetch_message(message_id=message_id)
|
||||
if message.parent_id is None:
|
||||
valid_labels = settings.tree_manager.labels_initial_prompt
|
||||
elif message.role == "assistant":
|
||||
valid_labels = settings.tree_manager.labels_assistant_reply
|
||||
else:
|
||||
valid_labels = settings.tree_manager.labels_prompter_reply
|
||||
else:
|
||||
valid_labels = [l for l in TextLabel if l != TextLabel.fails_task]
|
||||
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in TextLabel
|
||||
for l in valid_labels
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel):
|
||||
num_required_rankings: int = 3
|
||||
"""Number of rankings in which the message participated."""
|
||||
|
||||
p_activate_backlog_tree: float = 0.8
|
||||
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
|
||||
a terminal state. Use this settting to control ratio of initial prompts and backlog tree
|
||||
activations."""
|
||||
|
||||
min_active_rankings_per_lang: int = 2
|
||||
"""When the number of active ranking tasks is below this value when a tree enters a terminal
|
||||
state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state)."""
|
||||
|
||||
labels_initial_prompt: list[TextLabel] = [
|
||||
TextLabel.spam,
|
||||
TextLabel.quality,
|
||||
@@ -170,9 +179,9 @@ class Settings(BaseSettings):
|
||||
|
||||
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
|
||||
|
||||
USER_STATS_INTERVAL_DAY: int = 15 # minutes
|
||||
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
|
||||
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
|
||||
USER_STATS_INTERVAL_DAY: int = 5 # minutes
|
||||
USER_STATS_INTERVAL_WEEK: int = 15 # minutes
|
||||
USER_STATS_INTERVAL_MONTH: int = 60 # minutes
|
||||
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
|
||||
|
||||
@validator(
|
||||
|
||||
@@ -57,6 +57,11 @@ class Message(SQLModel, table=True):
|
||||
|
||||
rank: Optional[int] = Field(nullable=True)
|
||||
|
||||
synthetic: Optional[bool] = Field(
|
||||
sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)
|
||||
)
|
||||
model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
|
||||
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
|
||||
|
||||
|
||||
@@ -43,6 +43,9 @@ class State(str, Enum):
|
||||
HALTED_BY_MODERATOR = "halted_by_moderator"
|
||||
"""A moderator decided to manually halt the message tree construction process."""
|
||||
|
||||
BACKLOG_RANKING = "backlog_ranking"
|
||||
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
State.INITIAL_PROMPT_REVIEW,
|
||||
@@ -51,6 +54,7 @@ VALID_STATES = (
|
||||
State.READY_FOR_SCORING,
|
||||
State.READY_FOR_EXPORT,
|
||||
State.ABORTED_LOW_GRADE,
|
||||
State.BACKLOG_RANKING,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
|
||||
@@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True):
|
||||
max_children_count: int = Field(nullable=False)
|
||||
state: str = Field(nullable=False, max_length=128, index=True)
|
||||
active: bool = Field(nullable=False, index=True)
|
||||
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
|
||||
|
||||
@@ -177,6 +177,7 @@ class PromptRepository:
|
||||
review_count: int = 0,
|
||||
review_result: bool = False,
|
||||
check_tree_state: bool = True,
|
||||
check_duplicate: bool = True,
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
@@ -199,7 +200,7 @@ class PromptRepository:
|
||||
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
|
||||
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
|
||||
|
||||
if self.check_users_recent_replies_for_duplicates(text):
|
||||
if check_duplicate and self.check_users_recent_replies_for_duplicates(text):
|
||||
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
|
||||
|
||||
if task.parent_message_id:
|
||||
@@ -480,6 +481,7 @@ class PromptRepository:
|
||||
task_id=task.id if task else None,
|
||||
)
|
||||
|
||||
message: Message = None
|
||||
if message_id:
|
||||
if not task:
|
||||
if text_labels.is_report is True:
|
||||
@@ -909,8 +911,7 @@ FROM (
|
||||
) AS cc
|
||||
WHERE message.id = cc.id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
|
||||
self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
|
||||
@@ -25,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
from sqlmodel import Session, func, not_, or_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -73,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
role: str
|
||||
children_count: int
|
||||
child_min_ranking_count: int
|
||||
message_tree_id: UUID
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
@@ -360,7 +361,7 @@ class TreeManager:
|
||||
random_reply_message = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message)
|
||||
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
conversation = prepare_conversation(messages)
|
||||
message = messages[-1]
|
||||
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
@@ -625,19 +626,28 @@ class TreeManager:
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
|
||||
assert mts and mts.active
|
||||
assert mts
|
||||
|
||||
is_terminal = state in message_tree_state.TERMINAL_STATES
|
||||
|
||||
was_active = mts.active
|
||||
if is_terminal:
|
||||
mts.active = False
|
||||
mts.state = state.value
|
||||
self.db.add(mts)
|
||||
self.db.flush
|
||||
|
||||
if is_terminal:
|
||||
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
|
||||
root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False)
|
||||
if root_msg and was_active:
|
||||
if random.random() < self.cfg.p_activate_backlog_tree:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
|
||||
if self.cfg.min_active_rankings_per_lang > 0:
|
||||
incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang)
|
||||
if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang:
|
||||
self.activate_backlog_tree(lang=root_msg.lang)
|
||||
else:
|
||||
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
|
||||
|
||||
@@ -680,24 +690,30 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
return True
|
||||
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.RANKING:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False, None
|
||||
if mts.state != message_tree_state.State.SCORING_FAILED:
|
||||
if not mts.active or mts.state not in (
|
||||
message_tree_state.State.RANKING,
|
||||
message_tree_state.State.READY_FOR_SCORING,
|
||||
):
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
return False, None
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
if (
|
||||
mts.state != message_tree_state.State.SCORING_FAILED
|
||||
and mts.state != message_tree_state.State.READY_FOR_SCORING
|
||||
):
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
self.update_message_ranks(message_tree_id, rankings_by_message)
|
||||
return True
|
||||
|
||||
@@ -759,8 +775,35 @@ class TreeManager:
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
|
||||
|
||||
return True
|
||||
|
||||
def activate_backlog_tree(self, lang: str) -> MessageTreeState:
|
||||
while True:
|
||||
# find tree in backlog state
|
||||
backlog_tree: MessageTreeState = (
|
||||
self.db.query(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.id) # root msg
|
||||
.filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING)
|
||||
.filter(Message.lang == lang)
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not backlog_tree:
|
||||
return None
|
||||
|
||||
if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0:
|
||||
logger.info(
|
||||
f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state."
|
||||
)
|
||||
self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
else:
|
||||
logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}")
|
||||
backlog_tree.active = True
|
||||
self._enter_state(backlog_tree, message_tree_state.State.RANKING)
|
||||
return backlog_tree
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
@@ -828,7 +871,8 @@ class TreeManager:
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings,
|
||||
mts.message_tree_id
|
||||
FROM message_tree_state mts
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active -- only consider active trees
|
||||
@@ -837,7 +881,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.lang = :lang -- matches lang
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id, m.role
|
||||
GROUP BY m.parent_id, m.role, mts.message_tree_id
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
@@ -846,7 +890,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
|
||||
SELECT ir.* FROM incomplete_rankings ir
|
||||
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings,
|
||||
ir.message_tree_id
|
||||
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
"""
|
||||
|
||||
@@ -985,8 +1030,8 @@ SELECT p.parent_id, mr.* FROM
|
||||
GROUP BY m.parent_id, m.message_tree_id
|
||||
HAVING COUNT(m.id) > 1
|
||||
) as p
|
||||
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
"""
|
||||
|
||||
def query_tree_ranking_results(
|
||||
@@ -1029,7 +1074,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
rankings = (
|
||||
self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all()
|
||||
self.db.query(MessageTreeState)
|
||||
.filter(
|
||||
or_(
|
||||
MessageTreeState.state == message_tree_state.State.RANKING,
|
||||
MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if len(rankings) > 0:
|
||||
logger.info(f"Checking state of {len(rankings)} message trees in ranking state.")
|
||||
@@ -1322,17 +1374,17 @@ DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def retry_scoring_failed_message_trees(self):
|
||||
query = self.db.query(MessageTreeState.message_tree_id).filter(
|
||||
query = self.db.query(MessageTreeState).filter(
|
||||
MessageTreeState.state == message_tree_state.State.SCORING_FAILED
|
||||
)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
for row in query.all():
|
||||
for mts in query.all():
|
||||
mts: MessageTreeState
|
||||
try:
|
||||
message_tree_id = row["message_tree_id"]
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message)
|
||||
if not self.check_condition_for_scoring_state(mts.message_tree_id):
|
||||
mts.active = True
|
||||
self._enter_state(message_tree_state.State.RANKING)
|
||||
except Exception:
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})")
|
||||
logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1366,8 +1418,8 @@ if __name__ == "__main__":
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
# print(
|
||||
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
# )
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b"))
|
||||
)
|
||||
|
||||
# print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
|
||||
|
||||
@@ -117,13 +117,20 @@ class UserRepository:
|
||||
self.db.add(user)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
def _lookup_user_tx(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
auth_method: str,
|
||||
display_name: Optional[str] = None,
|
||||
create_missing: bool = True,
|
||||
) -> User | None:
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
User.username == username,
|
||||
User.auth_method == auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -131,30 +138,46 @@ class UserRepository:
|
||||
if create_missing:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
auth_method=auth_method,
|
||||
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
|
||||
)
|
||||
self.db.add(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
elif display_name and display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
user.display_name = display_name
|
||||
self.db.add(user)
|
||||
|
||||
return user
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
|
||||
if not client_user:
|
||||
return None
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
return self._lookup_client_user_tx(client_user, create_missing)
|
||||
return self._lookup_user_tx(
|
||||
username=client_user.id,
|
||||
auth_method=client_user.auth_method,
|
||||
display_name=client_user.display_name,
|
||||
create_missing=create_missing,
|
||||
)
|
||||
except IntegrityError:
|
||||
# catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username
|
||||
if i + 1 == num_retries:
|
||||
raise
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None:
|
||||
return self._lookup_user_tx(
|
||||
username=username,
|
||||
auth_method="system",
|
||||
display_name=f"__system__/{username}",
|
||||
create_missing=create_missing,
|
||||
)
|
||||
|
||||
def query_users_ordered_by_username(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
|
||||
@@ -214,6 +214,10 @@ class UserStatsRepository:
|
||||
d = delete(UserStats).where(UserStats.time_frame == time_frame_key)
|
||||
self.session.execute(d)
|
||||
|
||||
if None in stats_by_user:
|
||||
logger.warning("Some messages in DB have NULL values in user_id column.")
|
||||
del stats_by_user[None]
|
||||
|
||||
# compute magic leader score
|
||||
for v in stats_by_user.values():
|
||||
v.leader_score = v.compute_leader_score()
|
||||
|
||||
@@ -12,12 +12,15 @@ from pydantic import BaseModel
|
||||
|
||||
class ExportMessageNode(BaseModel):
|
||||
message_id: str
|
||||
parent_id: Optional[str]
|
||||
text: Optional[str]
|
||||
parent_id: str | None
|
||||
text: str
|
||||
role: str
|
||||
review_count: Optional[int]
|
||||
rank: Optional[int]
|
||||
replies: Optional[list[ExportMessageNode]]
|
||||
lang: str | None
|
||||
review_count: int | None
|
||||
rank: int | None
|
||||
synthetic: bool | None
|
||||
model_name: str | None
|
||||
replies: list[ExportMessageNode] | None
|
||||
|
||||
@classmethod
|
||||
def prep_message_export(cls, message: Message) -> ExportMessageNode:
|
||||
@@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel):
|
||||
parent_id=str(message.parent_id) if message.parent_id else None,
|
||||
text=str(message.payload.payload.text),
|
||||
role=message.role,
|
||||
lang=message.lang,
|
||||
review_count=message.review_count,
|
||||
synthetic=message.synthetic,
|
||||
model_name=message.model_name,
|
||||
rank=message.rank,
|
||||
)
|
||||
|
||||
|
||||
class ExportMessageTree(BaseModel):
|
||||
message_tree_id: str
|
||||
replies: Optional[ExportMessageNode]
|
||||
prompt: Optional[ExportMessageNode]
|
||||
|
||||
|
||||
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
|
||||
|
||||
@@ -13,5 +13,6 @@ RUN pip install -e /oasst-shared
|
||||
COPY ./backend/alembic /app/alembic
|
||||
COPY ./backend/alembic.ini /app/alembic.ini
|
||||
COPY ./backend/main.py /app/main.py
|
||||
COPY ./backend/import.py /app/import.py
|
||||
COPY ./backend/oasst_backend /app/oasst_backend
|
||||
COPY ./backend/test_data /app/test_data
|
||||
|
||||
@@ -16,4 +16,5 @@ tmux split-window -h
|
||||
tmux send-keys "cd text-client" C-m
|
||||
tmux send-keys "sleep 5" C-m
|
||||
tmux send-keys "python __main__.py" C-m
|
||||
tmux select-layout even-horizontal
|
||||
tmux attach-session -t "inference-dev-setup"
|
||||
|
||||
@@ -258,7 +258,7 @@ class LabelConversationReplyTask(AbstractLabelTask):
|
||||
"""A task to label a reply to a conversation."""
|
||||
|
||||
type: Literal["label_conversation_reply"] = "label_conversation_reply"
|
||||
conversation: Conversation # the conversation so far
|
||||
conversation: Conversation # the conversation so far (new: including the reply message)
|
||||
reply_message: Optional[ConversationMessage]
|
||||
reply: str
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@ describe("labeling assistant replies", () => {
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
|
||||
@@ -11,6 +11,10 @@ describe("labeling initial prompts", () => {
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
|
||||
@@ -11,6 +11,10 @@ describe("labeling prompter replies", () => {
|
||||
// For specific task pages the no task available result is normal.
|
||||
if (type === undefined) return;
|
||||
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
|
||||
@@ -44,6 +44,10 @@ describe("handles random tasks", () => {
|
||||
break;
|
||||
}
|
||||
case "label-task": {
|
||||
cy.get('[data-cy="label-question"]').each((label) => {
|
||||
// Click the no button, this generally approves the spam check
|
||||
cy.wrap(label).find('[data-cy="no"]').click();
|
||||
});
|
||||
cy.get('[data-cy="label-options"]').each((label) => {
|
||||
// Click the 4th option
|
||||
cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click();
|
||||
@@ -55,15 +59,6 @@ describe("handles random tasks", () => {
|
||||
|
||||
break;
|
||||
}
|
||||
case "spam-task": {
|
||||
cy.get('[data-cy="not-spam-button"]').click();
|
||||
|
||||
cy.get('[data-cy="review"]').click();
|
||||
|
||||
cy.get('[data-cy="submit"]').click();
|
||||
|
||||
break;
|
||||
}
|
||||
case undefined: {
|
||||
throw new Error("No tasks available, but at least create initial prompt expected");
|
||||
}
|
||||
|
||||
@@ -11,10 +11,12 @@
|
||||
"legal": "Legal",
|
||||
"loading": "Loading...",
|
||||
"more_information": "More Information",
|
||||
"no": "No",
|
||||
"privacy_policy": "Privacy Policy",
|
||||
"report_a_bug": "Report a Bug",
|
||||
"sign_in": "Sign In",
|
||||
"sign_out": "Sign Out",
|
||||
"terms_of_service": "Terms of Service",
|
||||
"title": "Open Assistant"
|
||||
"title": "Open Assistant",
|
||||
"yes": "Yes"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"label_highlighted_yes_no_instruction": "Answer the following question(s) about the highlighted message:",
|
||||
"label_highlighted_flag_instruction": "Select any that apply to the highlighted message:",
|
||||
"label_highlighted_likert_instruction": "Rate the highlighted message:",
|
||||
"label_message_yes_no_instruction": "Answer the following question(s) about the message:",
|
||||
"label_message_flag_instruction": "Select any that apply to the message:",
|
||||
"label_message_likert_instruction": "Rate the message:",
|
||||
"spam.question": "Is the message spam?",
|
||||
"fails_task.question": "Does the reply fail the prompter's task?",
|
||||
"not_appropriate": "Not Appropriate",
|
||||
"pii": "Contains PII",
|
||||
"hate_speech": "Hate Speech",
|
||||
"sexual_content": "Sexual Content",
|
||||
"moral_judgement": "Judges Morality",
|
||||
"political_content": "Political"
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
{
|
||||
"reactions": "Reactions",
|
||||
"label_action": "Label",
|
||||
"label_title": "Label",
|
||||
"submit_labels": "Submit",
|
||||
"message": "Message",
|
||||
"open_new_tab_action": "Open in new tab",
|
||||
"report_title": "Report",
|
||||
"parent": "Parent",
|
||||
"reactions": "Reactions",
|
||||
"report_action": "Report",
|
||||
"report_placeholder": "Why should this message be reviewed?",
|
||||
"send_report": "Send"
|
||||
"report_title": "Report",
|
||||
"send_report": "Send",
|
||||
"submit_labels": "Submit"
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
{
|
||||
"write_initial_prompt": "Write your prompt here...",
|
||||
"default": {
|
||||
"unchanged_title": "No changes",
|
||||
"unchanged_message": "Are you sure you would like to continue?"
|
||||
@@ -12,18 +11,21 @@
|
||||
"label": "Create Initial Prompts",
|
||||
"desc": "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
"overview": "Create an initial message to send to the assistant",
|
||||
"instruction": "Provide the initial prompts"
|
||||
"instruction": "Provide the initial prompts",
|
||||
"response_placeholder": "Write your prompt here..."
|
||||
},
|
||||
"reply_as_user": {
|
||||
"label": "Reply as User",
|
||||
"desc": "Chat with Open Assistant and help improve it's responses as you interact with it.",
|
||||
"overview": "Given the following conversation, provide an adequate reply",
|
||||
"instruction": "Provide the user's reply"
|
||||
"instruction": "Provide the user's reply",
|
||||
"response_placeholder": "Write your reply here..."
|
||||
},
|
||||
"reply_as_assistant": {
|
||||
"label": "Reply as Assistant",
|
||||
"desc": "Help Open Assistant improve its responses to conversations with other users.",
|
||||
"overview": "Given the following conversation, provide an adequate reply"
|
||||
"overview": "Given the following conversation, provide an adequate reply",
|
||||
"response_placeholder": "Write your reply here..."
|
||||
},
|
||||
"rank_user_replies": {
|
||||
"label": "Rank User Replies",
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Popover,
|
||||
PopoverAnchor,
|
||||
PopoverTrigger,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
import { LabelInputGroup } from "./Survey/LabelInputGroup";
|
||||
|
||||
interface Label {
|
||||
name: string;
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
}
|
||||
|
||||
interface FlaggableElementProps {
|
||||
children: React.ReactNode;
|
||||
message: Message;
|
||||
}
|
||||
|
||||
interface ValidLabelsResponse {
|
||||
valid_labels: Label[];
|
||||
}
|
||||
|
||||
export const FlaggableElement = (props: FlaggableElementProps) => {
|
||||
const { data: response } = useSWRImmutable<ValidLabelsResponse>("/api/valid_labels", get);
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const { valid_labels } = response || { valid_labels: [] };
|
||||
const [values, setValues] = useState<number[]>([]);
|
||||
|
||||
const submittable =
|
||||
values.some((value) => {
|
||||
return value !== null;
|
||||
}) &&
|
||||
values.length === valid_labels.length &&
|
||||
valid_labels.length > 0;
|
||||
|
||||
const { trigger } = useSWRMutation("/api/set_label", post, {
|
||||
onSuccess: onClose,
|
||||
onError: onClose,
|
||||
});
|
||||
|
||||
const submitResponse = () => {
|
||||
const label_map: Map<string, number> = new Map();
|
||||
console.assert(valid_labels.length === values.length);
|
||||
values.forEach((value, idx) => {
|
||||
if (value !== null) {
|
||||
label_map.set(valid_labels[idx].name, value);
|
||||
}
|
||||
});
|
||||
trigger({
|
||||
message_id: props.message.id,
|
||||
label_map: Object.fromEntries(label_map),
|
||||
text: props.message.text,
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Popover isOpen={isOpen} onOpen={onOpen} onClose={onClose} closeOnBlur={false} isLazy lazyBehavior="keepMounted">
|
||||
<Box display="flex" alignItems="center" flexDirection={["column", "row"]} gap="2">
|
||||
<PopoverAnchor>{props.children}</PopoverAnchor>
|
||||
|
||||
<Tooltip label="Report" bg="red.500" aria-label="A tooltip">
|
||||
<Box>
|
||||
<PopoverTrigger>
|
||||
<Box as="button" display="flex" alignItems="center" justifyContent="center" borderRadius="full" p="1">
|
||||
<AlertCircle size="20" className="text-red-400" aria-hidden="true" />
|
||||
</Box>
|
||||
</PopoverTrigger>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
|
||||
<Modal isOpen={isOpen} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Select one or more labels that apply.</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<LabelInputGroup labelIDs={valid_labels.map(({ name }) => name)} onChange={setValues} />
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button
|
||||
isDisabled={!submittable}
|
||||
onClick={submitResponse}
|
||||
className={`bg-indigo-600 text-${useColorModeValue(
|
||||
colors.light.text,
|
||||
colors.dark.text
|
||||
)} hover:bg-indigo-700`}
|
||||
>
|
||||
Report
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
@@ -1,25 +1,7 @@
|
||||
import { Box, forwardRef, Grid, useColorMode } from "@chakra-ui/react";
|
||||
import { Box, forwardRef, useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
import { Message } from "src/types/Conversation";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
interface MessagesProps {
|
||||
messages: Message[];
|
||||
}
|
||||
|
||||
export const Messages = ({ messages }: MessagesProps) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
return (
|
||||
<FlaggableElement message={messageProps} key={i + messageProps.id}>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
});
|
||||
// Maybe also show a legend of the colors?
|
||||
return <Grid gap={2}>{items}</Grid>;
|
||||
};
|
||||
|
||||
export const MessageView = forwardRef<Partial<Message>, "div">((message: Partial<Message>, ref) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import { Button, Flex } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
|
||||
interface LabelFlagGroupProps {
|
||||
values: number[];
|
||||
labelNames: string[];
|
||||
isEditable?: boolean;
|
||||
onChange: (values: number[]) => void;
|
||||
}
|
||||
|
||||
export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange }: LabelFlagGroupProps) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
return (
|
||||
<Flex wrap="wrap" gap="4">
|
||||
{labelNames.map((name, idx) => (
|
||||
<Button
|
||||
key={name}
|
||||
onClick={() => {
|
||||
const newValues = values.slice();
|
||||
newValues[idx] = newValues[idx] ? 0 : 1;
|
||||
onChange(newValues);
|
||||
}}
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={values[idx] === 1 ? "blue" : undefined}
|
||||
>
|
||||
{t(getTypeSafei18nKey(name))}
|
||||
</Button>
|
||||
))}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,84 @@
|
||||
import { Text, VStack } from "@chakra-ui/react";
|
||||
import { Label } from "src/types/Tasks";
|
||||
|
||||
import { LabelLikertGroup } from "../Survey/LabelLikertGroup";
|
||||
import { LabelFlagGroup } from "./LabelFlagGroup";
|
||||
import { LabelYesNoGroup } from "./LabelYesNoGroup";
|
||||
|
||||
export interface LabelInputInstructions {
|
||||
yesNoInstruction: string;
|
||||
flagInstruction: string;
|
||||
likertInstruction: string;
|
||||
}
|
||||
|
||||
interface LabelInputGroupProps {
|
||||
values: number[];
|
||||
labels: Label[];
|
||||
requiredLabels?: string[];
|
||||
isEditable?: boolean;
|
||||
instructions: LabelInputInstructions;
|
||||
onChange: (values: number[]) => void;
|
||||
}
|
||||
|
||||
export const LabelInputGroup = ({
|
||||
labels,
|
||||
values,
|
||||
requiredLabels,
|
||||
isEditable,
|
||||
instructions,
|
||||
onChange,
|
||||
}: LabelInputGroupProps) => {
|
||||
const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null);
|
||||
const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null);
|
||||
const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? idx : null)).filter((v) => v !== null);
|
||||
|
||||
return (
|
||||
<VStack alignItems="stretch" spacing={6}>
|
||||
{yesNoIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.yesNoInstruction}</Text>
|
||||
<LabelYesNoGroup
|
||||
values={yesNoIndexes.map((idx) => values[idx])}
|
||||
labelNames={yesNoIndexes.map((idx) => labels[idx].name)}
|
||||
isEditable={isEditable}
|
||||
requiredLabels={requiredLabels}
|
||||
onChange={(yesNoValues) => {
|
||||
const newValues = values.slice();
|
||||
yesNoIndexes.forEach((idx, yesNoIndex) => (newValues[idx] = yesNoValues[yesNoIndex]));
|
||||
onChange(newValues);
|
||||
}}
|
||||
/>
|
||||
</VStack>
|
||||
)}
|
||||
{flagIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.flagInstruction}</Text>
|
||||
<LabelFlagGroup
|
||||
values={flagIndexes.map((idx) => values[idx])}
|
||||
labelNames={flagIndexes.map((idx) => labels[idx].name)}
|
||||
isEditable={isEditable}
|
||||
onChange={(flagValues) => {
|
||||
const newValues = values.slice();
|
||||
flagIndexes.forEach((idx, flagIndex) => (newValues[idx] = flagValues[flagIndex]));
|
||||
onChange(newValues);
|
||||
}}
|
||||
/>
|
||||
</VStack>
|
||||
)}
|
||||
{likertIndexes.length > 0 && (
|
||||
<VStack alignItems="stretch" spacing={2}>
|
||||
<Text>{instructions.likertInstruction}</Text>
|
||||
<LabelLikertGroup
|
||||
labelIDs={likertIndexes.map((idx) => labels[idx].name)}
|
||||
isEditable={isEditable}
|
||||
onChange={(likertValues) => {
|
||||
const newValues = values.slice();
|
||||
likertIndexes.forEach((idx, likertIndex) => (newValues[idx] = likertValues[likertIndex]));
|
||||
onChange(newValues);
|
||||
}}
|
||||
/>
|
||||
</VStack>
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
@@ -9,9 +9,10 @@ import {
|
||||
ModalOverlay,
|
||||
} from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useState } from "react";
|
||||
import { LabelInputGroup } from "src/components/Survey/LabelInputGroup";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LabelInputGroup } from "src/components/Messages/LabelInputGroup";
|
||||
import { get, post } from "src/lib/api";
|
||||
import { Label } from "src/types/Tasks";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
@@ -21,21 +22,19 @@ interface LabelMessagePopupProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
interface Label {
|
||||
name: string;
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
}
|
||||
|
||||
interface ValidLabelsResponse {
|
||||
valid_labels: Label[];
|
||||
}
|
||||
|
||||
export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => {
|
||||
const { t } = useTranslation("message");
|
||||
const { data: response } = useSWRImmutable<ValidLabelsResponse>("/api/valid_labels", get);
|
||||
const { t } = useTranslation();
|
||||
const { data: response } = useSWRImmutable<ValidLabelsResponse>(`/api/valid_labels?message_id=${messageId}`, get);
|
||||
const valid_labels = response?.valid_labels ?? [];
|
||||
const [values, setValues] = useState<number[]>(null);
|
||||
const [values, setValues] = useState<number[]>(new Array(valid_labels.length).fill(null));
|
||||
|
||||
useEffect(() => {
|
||||
setValues(new Array(valid_labels.length).fill(null));
|
||||
}, [messageId, valid_labels.length]);
|
||||
|
||||
const { trigger: setLabels } = useSWRMutation("/api/set_label", post);
|
||||
|
||||
@@ -60,14 +59,23 @@ export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopu
|
||||
<Modal isOpen={show} onClose={onClose}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t("label_title")}</ModalHeader>
|
||||
<ModalHeader>{t("message:label_title")}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<LabelInputGroup labelIDs={valid_labels.map(({ name }) => name)} onChange={setValues} />
|
||||
<LabelInputGroup
|
||||
labels={valid_labels}
|
||||
values={values}
|
||||
instructions={{
|
||||
yesNoInstruction: t("labelling:label_message_yes_no_instruction"),
|
||||
flagInstruction: t("labelling:label_message_flag_instruction"),
|
||||
likertInstruction: t("labelling:label_message_likert_instruction"),
|
||||
}}
|
||||
onChange={setValues}
|
||||
/>
|
||||
</ModalBody>
|
||||
<ModalFooter>
|
||||
<Button colorScheme="blue" mr={3} onClick={submit}>
|
||||
{t("submit_labels")}
|
||||
{t("message:submit_labels")}
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
import { Button, HStack, Text, Tooltip } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
|
||||
interface LabelYesNoGroupProps {
|
||||
values: number[];
|
||||
labelNames: string[];
|
||||
requiredLabels?: string[];
|
||||
isEditable?: boolean;
|
||||
onChange: (values: number[]) => void;
|
||||
}
|
||||
|
||||
export const LabelYesNoGroup = ({
|
||||
values,
|
||||
labelNames,
|
||||
requiredLabels = [],
|
||||
isEditable = true,
|
||||
onChange,
|
||||
}: LabelYesNoGroupProps) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
return (
|
||||
<>
|
||||
{labelNames.map((name, idx) => {
|
||||
return (
|
||||
<YesNoQuestion
|
||||
key={name}
|
||||
question={t(getTypeSafei18nKey(`${name}.question`))}
|
||||
value={values[idx] === null ? null : values[idx] > 0.1 ? true : false}
|
||||
onChange={(value) => {
|
||||
const newValues = values.slice();
|
||||
newValues[idx] = value;
|
||||
onChange(newValues);
|
||||
}}
|
||||
isEditable={isEditable}
|
||||
isRequired={requiredLabels.includes(name)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const YesNoQuestion = ({
|
||||
isEditable,
|
||||
question,
|
||||
value,
|
||||
isRequired,
|
||||
onChange,
|
||||
}: {
|
||||
isEditable: boolean;
|
||||
question: string;
|
||||
value: boolean;
|
||||
isRequired?: boolean;
|
||||
onChange: (boolean) => void;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<div data-cy="label-question" style={{ maxWidth: "30em" }}>
|
||||
<Text display="inline">
|
||||
{question}
|
||||
{isRequired ? <RequiredMark /> : undefined}
|
||||
</Text>
|
||||
<HStack style={{ float: "right" }}>
|
||||
<Button
|
||||
data-cy="yes"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === true ? "blue" : undefined}
|
||||
onClick={() => onChange(isRequired ? true : value === null ? true : null)}
|
||||
>
|
||||
{t("yes")}
|
||||
</Button>
|
||||
<Button
|
||||
data-cy="no"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === false ? "blue" : undefined}
|
||||
onClick={() => onChange(isRequired ? false : value === null ? false : null)}
|
||||
>
|
||||
{t("no")}
|
||||
</Button>
|
||||
</HStack>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const RequiredMark = () => (
|
||||
<Tooltip label="Required">
|
||||
<span style={{ color: "red" }}>*</span>
|
||||
</Tooltip>
|
||||
);
|
||||
+3
-3
@@ -135,9 +135,9 @@ const getLabelInfo = (label: string): LabelInfo => {
|
||||
oneDescription: ["Contains text which is incorrect or misleading"],
|
||||
inverted: true,
|
||||
};
|
||||
case "helpful":
|
||||
case "helpfulness":
|
||||
return {
|
||||
zeroText: "Unhelful",
|
||||
zeroText: "Unhelpful",
|
||||
zeroDescription: [],
|
||||
oneText: "Helpful",
|
||||
oneDescription: ["Completes the task to a high standard"],
|
||||
@@ -186,7 +186,7 @@ const getLabelInfo = (label: string): LabelInfo => {
|
||||
}
|
||||
};
|
||||
|
||||
export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => {
|
||||
export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => {
|
||||
const [labelValues, setLabelValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => null));
|
||||
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
@@ -4,12 +4,10 @@ import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
import { TaskStatus } from "src/components/Tasks/Task";
|
||||
import { BaseTask } from "src/types/Task";
|
||||
|
||||
export interface TaskControlsProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
task: any;
|
||||
className?: string;
|
||||
task: BaseTask;
|
||||
taskStatus: TaskStatus;
|
||||
onEdit: () => void;
|
||||
onReview: () => void;
|
||||
@@ -17,7 +15,7 @@ export interface TaskControlsProps {
|
||||
onSkip: (reason: string) => void;
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
export const TaskControls = ({ task, taskStatus, onEdit, onReview, onSubmit, onSkip }: TaskControlsProps) => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
|
||||
return (
|
||||
@@ -31,38 +29,32 @@ export const TaskControls = (props: TaskControlsProps) => {
|
||||
shadow="base"
|
||||
gap="4"
|
||||
>
|
||||
<TaskInfo id={props.task.id} output="Submit your answer" />
|
||||
<TaskInfo id={task.id} output="Submit your answer" />
|
||||
<Flex width={["full", "fit-content"]} justify="center" ml="auto" gap={2}>
|
||||
{props.taskStatus === "REVIEW" || props.taskStatus === "SUBMITTED" ? (
|
||||
{taskStatus.mode === "EDIT" ? (
|
||||
<>
|
||||
<Tooltip label="Edit">
|
||||
<IconButton
|
||||
size="lg"
|
||||
data-cy="edit"
|
||||
aria-label="edit"
|
||||
onClick={props.onEdit}
|
||||
icon={<Edit2 size="1em" />}
|
||||
/>
|
||||
</Tooltip>
|
||||
<SkipButton onSkip={onSkip} />
|
||||
<SubmitButton
|
||||
colorScheme="green"
|
||||
data-cy="submit"
|
||||
isDisabled={props.taskStatus === "SUBMITTED"}
|
||||
onClick={props.onSubmit}
|
||||
colorScheme="blue"
|
||||
data-cy="review"
|
||||
isDisabled={taskStatus.replyValidity === "INVALID"}
|
||||
onClick={onReview}
|
||||
>
|
||||
Submit
|
||||
Review
|
||||
</SubmitButton>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<SkipButton onSkip={props.onSkip} />
|
||||
<Tooltip label="Edit">
|
||||
<IconButton size="lg" data-cy="edit" aria-label="edit" onClick={onEdit} icon={<Edit2 size="1em" />} />
|
||||
</Tooltip>
|
||||
<SubmitButton
|
||||
colorScheme="blue"
|
||||
data-cy="review"
|
||||
isDisabled={props.taskStatus === "NOT_SUBMITTABLE"}
|
||||
onClick={props.onReview}
|
||||
colorScheme="green"
|
||||
data-cy="submit"
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onClick={onSubmit}
|
||||
>
|
||||
Review
|
||||
Submit
|
||||
</SubmitButton>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -7,6 +7,8 @@ import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks";
|
||||
|
||||
export const CreateTask = ({
|
||||
task,
|
||||
@@ -15,7 +17,7 @@ export const CreateTask = ({
|
||||
isDisabled,
|
||||
onReplyChanged,
|
||||
onValidityChanged,
|
||||
}: TaskSurveyProps<{ text: string }>) => {
|
||||
}: TaskSurveyProps<CreateInitialPromptTask | CreateAssistantReplyTask | CreatePrompterReplyTask, { text: string }>) => {
|
||||
const { t, i18n } = useTranslation(["tasks", "common"]);
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const titleColor = useColorModeValue("gray.800", "gray.300");
|
||||
@@ -39,7 +41,7 @@ export const CreateTask = ({
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<TaskHeader taskType={taskType} />
|
||||
{!!task.conversation && (
|
||||
{task.type !== TaskType.initial_prompt && (
|
||||
<Box mt="4" borderRadius="lg" bg={cardColor} className="p-3 sm:p-6">
|
||||
<MessageTable messages={task.conversation.messages} highlightLastMessage />
|
||||
</Box>
|
||||
@@ -56,7 +58,11 @@ export const CreateTask = ({
|
||||
text={inputText}
|
||||
onTextChange={textChangeHandler}
|
||||
thresholds={{ low: 20, medium: 40, goal: 50 }}
|
||||
textareaProps={{ placeholder: t("tasks:write_initial_prompt"), isDisabled, isReadOnly: !isEditable }}
|
||||
textareaProps={{
|
||||
placeholder: t(getTypeSafei18nKey(`tasks:${taskType.id}.response_placeholder`)),
|
||||
isDisabled,
|
||||
isReadOnly: !isEditable,
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
</>
|
||||
|
||||
@@ -5,6 +5,8 @@ import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks";
|
||||
|
||||
export const EvaluateTask = ({
|
||||
task,
|
||||
@@ -13,19 +15,25 @@ export const EvaluateTask = ({
|
||||
isDisabled,
|
||||
onReplyChanged,
|
||||
onValidityChanged,
|
||||
}: TaskSurveyProps<{ ranking: number[] }>) => {
|
||||
}: TaskSurveyProps<
|
||||
RankInitialPromptsTask | RankAssistantRepliesTask | RankPrompterRepliesTask,
|
||||
{ ranking: number[] }
|
||||
>) => {
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const [ranking, setRanking] = useState<number[]>(null);
|
||||
|
||||
let messages = [];
|
||||
if (task.conversation) {
|
||||
if (task.type !== TaskType.rank_initial_prompts) {
|
||||
messages = task.conversation.messages;
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (ranking === null) {
|
||||
const defaultRanking = (task.replies ?? task.prompts).map((_, idx) => idx);
|
||||
onReplyChanged({ ranking: defaultRanking });
|
||||
if (task.type === TaskType.rank_initial_prompts) {
|
||||
onReplyChanged({ ranking: task.prompts.map((_, idx) => idx) });
|
||||
} else {
|
||||
onReplyChanged({ ranking: task.replies.map((_, idx) => idx) });
|
||||
}
|
||||
onValidityChanged("DEFAULT");
|
||||
} else {
|
||||
onReplyChanged({ ranking });
|
||||
@@ -33,7 +41,7 @@ export const EvaluateTask = ({
|
||||
}
|
||||
}, [task, ranking, onReplyChanged, onValidityChanged]);
|
||||
|
||||
const sortables = task.replies ? "replies" : "prompts";
|
||||
const sortables = task.type === TaskType.rank_initial_prompts ? "prompts" : "replies";
|
||||
|
||||
return (
|
||||
<div data-cy="task" data-task-type="evaluate-task">
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import { Box, Button, Flex, HStack, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, useBoolean, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useEffect, useState } from "react";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { LabelInputGroup } from "src/components/Messages/LabelInputGroup";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { LabelInputGroup } from "src/components/Survey/LabelInputGroup";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { TaskType } from "src/types/Task";
|
||||
import { LabelTaskType } from "src/types/Tasks";
|
||||
|
||||
const isRequired = (labelName: string, requiredLabels?: string[]) => {
|
||||
return requiredLabels ? requiredLabels.includes(labelName) : false;
|
||||
};
|
||||
|
||||
export const LabelTask = ({
|
||||
task,
|
||||
@@ -13,15 +20,33 @@ export const LabelTask = ({
|
||||
isEditable,
|
||||
onReplyChanged,
|
||||
onValidityChanged,
|
||||
}: TaskSurveyProps<{ text: string; labels: Record<string, number>; message_id: string }>) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(new Array(task.valid_labels.length).fill(null));
|
||||
}: TaskSurveyProps<LabelTaskType, { text: string; labels: Record<string, number>; message_id: string }>) => {
|
||||
const { t } = useTranslation("labelling");
|
||||
const [values, setValues] = useState<number[]>(new Array(task.labels.length).fill(null));
|
||||
const [userInputMade, setUserInputMade] = useBoolean(false);
|
||||
|
||||
// Initial setup to run when the task changes
|
||||
useEffect(() => {
|
||||
console.assert(task.valid_labels.length === sliderValues.length);
|
||||
const labels = Object.fromEntries(task.valid_labels.map((label, i) => [label, sliderValues[i]]));
|
||||
onReplyChanged({ labels, text: task.reply || task.prompt, message_id: task.message_id });
|
||||
onValidityChanged(sliderValues.every((value) => value !== null) ? "VALID" : "INVALID");
|
||||
}, [task, sliderValues, onReplyChanged, onValidityChanged]);
|
||||
setValues(new Array(task.labels.length).fill(null));
|
||||
onValidityChanged(task.labels.some(({ name }) => isRequired(name, task.mandatory_labels)) ? "INVALID" : "DEFAULT");
|
||||
setUserInputMade.off();
|
||||
}, [task, setUserInputMade, onValidityChanged]);
|
||||
|
||||
// Update the reply and validity when the values change
|
||||
useEffect(() => {
|
||||
onReplyChanged({
|
||||
text: "unused?",
|
||||
labels: Object.fromEntries(task.labels.map(({ name }, idx) => [name, values[idx] || 0])),
|
||||
message_id: task.message_id,
|
||||
});
|
||||
onValidityChanged(
|
||||
task.labels.some(({ name }, idx) => values[idx] === null && isRequired(name, task.mandatory_labels))
|
||||
? "INVALID"
|
||||
: userInputMade
|
||||
? "VALID"
|
||||
: "DEFAULT"
|
||||
);
|
||||
}, [task, values, onReplyChanged, userInputMade, onValidityChanged]);
|
||||
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const isSpamTask = task.mode === "simple" && task.valid_labels.length === 1 && task.valid_labels[0] === "spam";
|
||||
@@ -31,12 +56,9 @@ export const LabelTask = ({
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<TaskHeader taskType={taskType} />
|
||||
{task.conversation ? (
|
||||
{task.type !== TaskType.label_initial_prompt ? (
|
||||
<Box mt="4" p={[4, 6]} borderRadius="lg" bg={cardColor}>
|
||||
<MessageTable
|
||||
messages={[...(task.conversation?.messages ?? []), task.reply_message]}
|
||||
highlightLastMessage
|
||||
/>
|
||||
<MessageTable messages={task.conversation.messages} highlightLastMessage />
|
||||
</Box>
|
||||
) : (
|
||||
<Box mt="4">
|
||||
@@ -44,51 +66,22 @@ export const LabelTask = ({
|
||||
</Box>
|
||||
)}
|
||||
</>
|
||||
{isSpamTask ? (
|
||||
<SpamTaskInput
|
||||
value={sliderValues[0]}
|
||||
onChange={(value) => setSliderValues([value])}
|
||||
isEditable={isEditable}
|
||||
/>
|
||||
) : (
|
||||
<Flex direction="column" alignItems="stretch">
|
||||
<Text>The highlighted message:</Text>
|
||||
<LabelInputGroup labelIDs={task.valid_labels} isEditable={isEditable} onChange={setSliderValues} />
|
||||
</Flex>
|
||||
)}
|
||||
<LabelInputGroup
|
||||
labels={task.labels}
|
||||
values={values}
|
||||
requiredLabels={task.mandatory_labels}
|
||||
isEditable={isEditable}
|
||||
instructions={{
|
||||
yesNoInstruction: t("label_highlighted_yes_no_instruction"),
|
||||
flagInstruction: t("label_highlighted_flag_instruction"),
|
||||
likertInstruction: t("label_highlighted_likert_instruction"),
|
||||
}}
|
||||
onChange={(values) => {
|
||||
setValues(values);
|
||||
setUserInputMade.on();
|
||||
}}
|
||||
/>
|
||||
</TwoColumnsWithCards>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const SpamTaskInput = ({
|
||||
isEditable,
|
||||
value,
|
||||
onChange,
|
||||
}: {
|
||||
isEditable: boolean;
|
||||
value: number;
|
||||
onChange: (number) => void;
|
||||
}) => {
|
||||
return (
|
||||
<HStack>
|
||||
<Text>Is the highlighted message spam?</Text>
|
||||
<Button
|
||||
data-cy="spam-button"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === 1 ? "blue" : undefined}
|
||||
onClick={() => onChange(1)}
|
||||
>
|
||||
Yes
|
||||
</Button>
|
||||
<Button
|
||||
data-cy="not-spam-button"
|
||||
isDisabled={!isEditable}
|
||||
colorScheme={value === 0 ? "blue" : undefined}
|
||||
onClick={() => onChange(0)}
|
||||
>
|
||||
No
|
||||
</Button>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useRef, useState } from "react";
|
||||
import { useCallback, useEffect, useReducer } from "react";
|
||||
import { useMemo, useRef } from "react";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { CreateTask } from "src/components/Tasks/CreateTask";
|
||||
import { EvaluateTask } from "src/components/Tasks/EvaluateTask";
|
||||
@@ -8,15 +9,52 @@ import { TaskCategory, TaskInfo, TaskInfos } from "src/components/Tasks/TaskType
|
||||
import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning";
|
||||
import { post } from "src/lib/api";
|
||||
import { getTypeSafei18nKey } from "src/lib/i18n";
|
||||
import { TaskContent, TaskReplyValidity } from "src/types/Task";
|
||||
import { BaseTask, TaskContent, TaskReplyValidity } from "src/types/Task";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
export type TaskStatus = "NOT_SUBMITTABLE" | "DEFAULT" | "VALID" | "REVIEW" | "SUBMITTED";
|
||||
interface EditMode {
|
||||
mode: "EDIT";
|
||||
replyValidity: TaskReplyValidity;
|
||||
}
|
||||
interface ReviewMode {
|
||||
mode: "REVIEW";
|
||||
}
|
||||
interface DefaultWarnMode {
|
||||
mode: "DEFAULT_WARN";
|
||||
}
|
||||
interface SubmittedMode {
|
||||
mode: "SUBMITTED";
|
||||
}
|
||||
|
||||
export interface TaskSurveyProps<T> {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
task: any;
|
||||
export type TaskStatus = EditMode | DefaultWarnMode | ReviewMode | SubmittedMode;
|
||||
|
||||
interface NewTask {
|
||||
action: "NEW_TASK";
|
||||
}
|
||||
|
||||
interface Review {
|
||||
action: "REVIEW";
|
||||
}
|
||||
|
||||
interface SetSubmitted {
|
||||
action: "SET_SUBMITTED";
|
||||
}
|
||||
|
||||
interface ReturnToEdit {
|
||||
action: "RETURN_EDIT";
|
||||
}
|
||||
|
||||
interface AcceptDefault {
|
||||
action: "ACCEPT_DEFAULT";
|
||||
}
|
||||
|
||||
interface UpdateValidity {
|
||||
action: "UPDATE_VALIDITY";
|
||||
replyValidity: TaskReplyValidity;
|
||||
}
|
||||
|
||||
export interface TaskSurveyProps<TaskType extends BaseTask, T> {
|
||||
task: TaskType;
|
||||
taskType: TaskInfo;
|
||||
isEditable: boolean;
|
||||
isDisabled?: boolean;
|
||||
@@ -26,13 +64,63 @@ export interface TaskSurveyProps<T> {
|
||||
|
||||
export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
const { t } = useTranslation("tasks");
|
||||
const [taskStatus, setTaskStatus] = useState<TaskStatus>("NOT_SUBMITTABLE");
|
||||
const [taskStatus, taskEvent] = useReducer(
|
||||
(
|
||||
status: TaskStatus,
|
||||
event: NewTask | UpdateValidity | AcceptDefault | Review | ReturnToEdit | SetSubmitted
|
||||
): TaskStatus => {
|
||||
switch (event.action) {
|
||||
case "NEW_TASK":
|
||||
return { mode: "EDIT", replyValidity: "INVALID" };
|
||||
case "UPDATE_VALIDITY":
|
||||
return status.mode === "EDIT" ? { mode: "EDIT", replyValidity: event.replyValidity } : status;
|
||||
case "ACCEPT_DEFAULT":
|
||||
return status.mode === "DEFAULT_WARN" ? { mode: "REVIEW" } : status;
|
||||
case "REVIEW": {
|
||||
if (status.mode === "EDIT") {
|
||||
switch (status.replyValidity) {
|
||||
case "DEFAULT":
|
||||
return { mode: "DEFAULT_WARN" };
|
||||
case "VALID":
|
||||
return { mode: "REVIEW" };
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
case "RETURN_EDIT": {
|
||||
switch (status.mode) {
|
||||
case "REVIEW":
|
||||
return { mode: "EDIT", replyValidity: "VALID" };
|
||||
case "DEFAULT_WARN":
|
||||
return { mode: "EDIT", replyValidity: "DEFAULT" };
|
||||
default:
|
||||
return status;
|
||||
}
|
||||
}
|
||||
case "SET_SUBMITTED": {
|
||||
return status.mode === "REVIEW" ? { mode: "SUBMITTED" } : status;
|
||||
}
|
||||
}
|
||||
},
|
||||
{ mode: "EDIT", replyValidity: "INVALID" }
|
||||
);
|
||||
|
||||
const replyContent = useRef<TaskContent>(null);
|
||||
const [showUnchangedWarning, setShowUnchangedWarning] = useState(false);
|
||||
const updateValidity = useCallback(
|
||||
(replyValidity: TaskReplyValidity) => taskEvent({ action: "UPDATE_VALIDITY", replyValidity }),
|
||||
[taskEvent]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
taskEvent({ action: "NEW_TASK" });
|
||||
}, [task.id, updateValidity]);
|
||||
|
||||
const rootEl = useRef<HTMLDivElement>(null);
|
||||
|
||||
const taskType = TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode);
|
||||
const taskType = useMemo(
|
||||
() => TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode),
|
||||
[task.type, task.mode]
|
||||
);
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, {
|
||||
onSuccess: async () => {
|
||||
@@ -47,79 +135,36 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
});
|
||||
};
|
||||
|
||||
const edit_mode = taskStatus === "NOT_SUBMITTABLE" || taskStatus === "DEFAULT" || taskStatus === "VALID";
|
||||
const submitted = taskStatus === "SUBMITTED";
|
||||
|
||||
const onValidityChanged = (validity: TaskReplyValidity) => {
|
||||
if (!edit_mode) return;
|
||||
switch (validity) {
|
||||
case "DEFAULT":
|
||||
if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT");
|
||||
break;
|
||||
case "VALID":
|
||||
if (taskStatus !== "VALID") setTaskStatus("VALID");
|
||||
break;
|
||||
case "INVALID":
|
||||
if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
const onReplyChanged = (content: TaskContent) => {
|
||||
replyContent.current = content;
|
||||
};
|
||||
|
||||
const reviewResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "DEFAULT":
|
||||
setShowUnchangedWarning(true);
|
||||
break;
|
||||
case "VALID":
|
||||
setTaskStatus("REVIEW");
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
const editResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "REVIEW":
|
||||
setTaskStatus("VALID");
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
};
|
||||
const onReplyChanged = useCallback(
|
||||
(content: TaskContent) => {
|
||||
replyContent.current = content;
|
||||
},
|
||||
[replyContent]
|
||||
);
|
||||
|
||||
const submitResponse = () => {
|
||||
switch (taskStatus) {
|
||||
case "REVIEW": {
|
||||
trigger({
|
||||
id: frontendId,
|
||||
update_type: taskType.update_type,
|
||||
content: replyContent.current,
|
||||
});
|
||||
setTaskStatus("SUBMITTED");
|
||||
scrollToTop(rootEl.current);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return;
|
||||
if (taskStatus.mode === "REVIEW") {
|
||||
trigger({
|
||||
id: frontendId,
|
||||
update_type: taskType.update_type,
|
||||
content: replyContent.current,
|
||||
});
|
||||
taskEvent({ action: "SET_SUBMITTED" });
|
||||
scrollToTop(rootEl.current);
|
||||
}
|
||||
};
|
||||
|
||||
function taskTypeComponent() {
|
||||
const taskTypeComponent = useMemo(() => {
|
||||
switch (taskType.category) {
|
||||
case TaskCategory.Create:
|
||||
return (
|
||||
<CreateTask
|
||||
task={task}
|
||||
taskType={taskType}
|
||||
isEditable={edit_mode}
|
||||
isDisabled={submitted}
|
||||
isEditable={taskStatus.mode === "EDIT"}
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onReplyChanged={onReplyChanged}
|
||||
onValidityChanged={onValidityChanged}
|
||||
onValidityChanged={updateValidity}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Evaluate:
|
||||
@@ -127,10 +172,10 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
<EvaluateTask
|
||||
task={task}
|
||||
taskType={taskType}
|
||||
isEditable={edit_mode}
|
||||
isDisabled={submitted}
|
||||
isEditable={taskStatus.mode === "EDIT"}
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onReplyChanged={onReplyChanged}
|
||||
onValidityChanged={onValidityChanged}
|
||||
onValidityChanged={updateValidity}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Label:
|
||||
@@ -138,37 +183,34 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
<LabelTask
|
||||
task={task}
|
||||
taskType={taskType}
|
||||
isEditable={edit_mode}
|
||||
isDisabled={submitted}
|
||||
isEditable={taskStatus.mode === "EDIT"}
|
||||
isDisabled={taskStatus.mode === "SUBMITTED"}
|
||||
onReplyChanged={onReplyChanged}
|
||||
onValidityChanged={onValidityChanged}
|
||||
onValidityChanged={updateValidity}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
}, [task, taskType, taskStatus.mode, onReplyChanged, updateValidity]);
|
||||
|
||||
return (
|
||||
<div ref={rootEl}>
|
||||
{taskTypeComponent()}
|
||||
{taskTypeComponent}
|
||||
<TaskControls
|
||||
task={task}
|
||||
taskStatus={taskStatus}
|
||||
onEdit={editResponse}
|
||||
onReview={reviewResponse}
|
||||
onEdit={() => taskEvent({ action: "RETURN_EDIT" })}
|
||||
onReview={() => taskEvent({ action: "REVIEW" })}
|
||||
onSubmit={submitResponse}
|
||||
onSkip={rejectTask}
|
||||
/>
|
||||
<UnchangedWarning
|
||||
show={showUnchangedWarning}
|
||||
show={taskStatus.mode === "DEFAULT_WARN"}
|
||||
title={t(getTypeSafei18nKey(`${taskType.id}.unchanged_title`)) || t("default.unchanged_title")}
|
||||
message={t(getTypeSafei18nKey(`${taskType.id}.unchanged_message`)) || t("default.unchanged_message")}
|
||||
continueButtonText={"Continue anyway"}
|
||||
onClose={() => setShowUnchangedWarning(false)}
|
||||
onClose={() => taskEvent({ action: "RETURN_EDIT" })}
|
||||
onContinueAnyway={() => {
|
||||
if (taskStatus === "DEFAULT") {
|
||||
setTaskStatus("REVIEW");
|
||||
setShowUnchangedWarning(false);
|
||||
}
|
||||
taskEvent({ action: "ACCEPT_DEFAULT" });
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -160,8 +160,8 @@ export class OasstApiClient {
|
||||
/**
|
||||
* Returns the valid labels for messages.
|
||||
*/
|
||||
async fetch_valid_text(): Promise<any> {
|
||||
return this.get(`/api/v1/text_labels/valid_labels`);
|
||||
async fetch_valid_text(messageId?: string): Promise<any> {
|
||||
return this.get("/api/v1/text_labels/valid_labels", { message_id: messageId });
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -5,8 +5,9 @@ import { createApiClient } from "src/lib/oasst_client_factory";
|
||||
* Returns the set of valid labels that can be applied to messages.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { message_id } = req.query;
|
||||
const client = await createApiClient(token);
|
||||
const valid_labels = await client.fetch_valid_text();
|
||||
const valid_labels = await client.fetch_valid_text(message_id as string);
|
||||
res.status(200).json(valid_labels);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { MessageLoading } from "src/components/Loading/MessageLoading";
|
||||
@@ -10,6 +11,7 @@ import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const MessageDetail = ({ id }: { id: string }) => {
|
||||
const { t } = useTranslation(["message", "common"]);
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
|
||||
const { isLoading: isLoadingParent, data: parent } = useSWRImmutable<Message>(`/api/messages/${id}/parent`, get);
|
||||
@@ -20,7 +22,7 @@ const MessageDetail = ({ id }: { id: string }) => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<title>{t("common:title")}</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
@@ -32,7 +34,7 @@ const MessageDetail = ({ id }: { id: string }) => {
|
||||
<>
|
||||
<Box pb="4">
|
||||
<Text fontWeight="bold" fontSize="xl" pb="2">
|
||||
Parent
|
||||
{t("parent")}
|
||||
</Text>
|
||||
<Box bg={backgroundColor} padding="4" borderRadius="xl" boxShadow="base" width="fit-content">
|
||||
<MessageTableEntry enabled message={parent} />
|
||||
@@ -54,7 +56,7 @@ MessageDetail.getLayout = (page) => getDashboardLayout(page);
|
||||
export const getServerSideProps = async ({ locale, query }) => ({
|
||||
props: {
|
||||
id: query.id,
|
||||
...(await serverSideTranslations(locale, ["common"])),
|
||||
...(await serverSideTranslations(locale, ["common", "message"])),
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
+21
-13
@@ -33,31 +33,39 @@ export interface RankPrompterRepliesTask extends BaseTask {
|
||||
replies: string[];
|
||||
}
|
||||
|
||||
export interface LabelAssistantReplyTask extends BaseTask {
|
||||
export interface Label {
|
||||
display_text: string;
|
||||
help_text: string;
|
||||
name: string;
|
||||
widget: "flag" | "yes_no" | "likert";
|
||||
}
|
||||
|
||||
export interface BaseLabelTask extends BaseTask {
|
||||
message_id: string;
|
||||
labels: Label[];
|
||||
valid_labels: string[];
|
||||
disposition: "spam" | "quality";
|
||||
mode: "simple" | "full";
|
||||
mandatory_labels?: string[];
|
||||
}
|
||||
|
||||
export interface LabelAssistantReplyTask extends BaseLabelTask {
|
||||
type: TaskType.label_assistant_reply;
|
||||
message_id: string;
|
||||
conversation: Conversation;
|
||||
reply_message: Message;
|
||||
reply: string;
|
||||
valid_labels: string[];
|
||||
mode: "simple" | "full";
|
||||
mandatory_labels?: string[];
|
||||
}
|
||||
|
||||
export interface LabelPrompterReplyTask extends BaseTask {
|
||||
export interface LabelPrompterReplyTask extends BaseLabelTask {
|
||||
type: TaskType.label_prompter_reply;
|
||||
message_id: string;
|
||||
conversation: Conversation;
|
||||
reply_message: Message;
|
||||
reply: string;
|
||||
valid_labels: string[];
|
||||
mode: "simple" | "full";
|
||||
mandatory_labels?: string[];
|
||||
}
|
||||
|
||||
export interface LabelInitialPromptTask extends BaseTask {
|
||||
export interface LabelInitialPromptTask extends BaseLabelTask {
|
||||
type: TaskType.label_initial_prompt;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export type LabelTaskType = LabelAssistantReplyTask | LabelPrompterReplyTask | LabelInitialPromptTask;
|
||||
|
||||
Vendored
+2
@@ -3,6 +3,7 @@ import type dashboard from "public/locales/en/dashboard.json";
|
||||
import type index from "public/locales/en/index.json";
|
||||
import type leaderboard from "public/locales/en/leaderboard.json";
|
||||
import type message from "public/locales/en/message.json";
|
||||
import type labelling from "public/locales/en/labelling.json";
|
||||
import type tasks from "public/locales/en/tasks.json";
|
||||
|
||||
declare module "i18next" {
|
||||
@@ -14,6 +15,7 @@ declare module "i18next" {
|
||||
leaderboard: typeof leaderboard;
|
||||
tasks: typeof tasks;
|
||||
message: typeof message;
|
||||
labelling: typeof labelling;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user